Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 108 additions & 37 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,22 +241,23 @@ def pytest_configure(config):
# at the end, so the combination is now safe.


def pytest_collection_modifyitems(session, config, items):
def pytest_collection_modifyitems(session, config, items): # noqa: PLR0912
"""Skip ST tests based on --platform, --runtime, --level filters; order L3 before L2."""
platform = config.getoption("--platform")
runtime_filter = config.getoption("--runtime")
level_filter = config.getoption("--level")

# Orchestrator L3 children set PTO_TARGET_NODEID to the single case they
# were dispatched for. Pytest's --case filter runs inside test_run (too
# late — other classes' st_worker fixtures already fired at setup). Skip
# everything except the target nodeid so the child stays narrow even when
# the parent invocation had broad positional args like ``examples tests/st``.
target_nodeid = os.environ.get("PTO_TARGET_NODEID")
if target_nodeid:
# Non-SceneTestCase items that declare @pytest.mark.device_count are
# "resource tests" — they don't participate in level-based dispatch and
# run in their own Resource phase. When --level is active, skip them so
# they don't try to allocate devices from the level-filtered pool.
if level_filter is not None:
for item in items:
if item.nodeid != target_nodeid:
item.add_marker(pytest.mark.skip(reason=f"dispatcher target is {target_nodeid}"))
if any(m.name == "skip" for m in item.iter_markers()):
continue
cls = getattr(item, "cls", None)
if cls is None and item.get_closest_marker("device_count"):
item.add_marker(pytest.mark.skip(reason=f"resource test (device_count), not level {level_filter}"))

# Sort: L3 tests first (they fork child processes that inherit main process CANN state,
# so they must run before L2 tests pollute the CANN context).
Expand Down Expand Up @@ -341,25 +342,19 @@ def _collect_st_runtimes(items, level=None):


def _collect_l3_cases(items, platform):
"""Collect one job per L3 class (not per case).
"""Collect one job per L3 ``SceneTestCase`` class (not per case).

Returns a list of tuples ``(nodeid, cls_name, runtime, max_device_count)``
where ``max_device_count`` is the maximum ``device_count`` across the
class's matching cases. Per-class dispatch matches the ``st_worker``
fixture's contract (it allocates ``max(CASES.device_count)`` for the whole
class) — dispatching per-case with a smaller device budget would trip the
fixture whenever the class also has a case that needs more devices.

Cases within a class still run in the child process via the existing
``test_run`` case loop, reusing the Worker (layer-4 reuse).
class's matching cases.
"""
by_nodeid: dict[str, tuple[str, str, int]] = {}
for item in items:
if any(m.name == "skip" for m in item.iter_markers()):
continue
cls = getattr(item, "cls", None)
if not cls or getattr(cls, "_st_level", None) != 3:
continue
if any(m.name == "skip" for m in item.iter_markers()):
continue
rt = getattr(cls, "_st_runtime", None)
if not rt:
continue
Expand All @@ -369,14 +364,43 @@ def _collect_l3_cases(items, platform):
if platform and platform not in case.get("platforms", []):
continue
if case.get("manual"):
continue # --manual exclude is the default; children honor the flag
continue
saw_case = True
max_dev = max(max_dev, int(case.get("config", {}).get("device_count", 1)))
if saw_case:
by_nodeid[item.nodeid] = (cls.__name__, rt, max_dev)
return [(nodeid, cls_name, rt, dev) for nodeid, (cls_name, rt, dev) in by_nodeid.items()]


def _collect_resource_cases(items, platform):
"""Collect non-``SceneTestCase`` pytest functions that declare resource needs.

Returns a list of tuples ``(nodeid, func_name, runtime, device_count)``.
These run in their own dispatch phase — they don't participate in
level-based dispatch. A function must carry both
``@pytest.mark.device_count(n)`` and ``@pytest.mark.runtime("...")``.
"""
by_nodeid: dict[str, tuple[str, str, int]] = {}
for item in items:
if any(m.name == "skip" for m in item.iter_markers()):
continue
cls = getattr(item, "cls", None)
if cls is not None:
continue
dev_marker = item.get_closest_marker("device_count")
if dev_marker is None:
continue
rt_marker = item.get_closest_marker("runtime")
if rt_marker is None or not rt_marker.args:
continue
platforms_marker = item.get_closest_marker("platforms")
if platforms_marker and platform and platform not in platforms_marker.args[0]:
continue
dev_count = int(dev_marker.args[0]) if dev_marker.args else 1
by_nodeid[item.nodeid] = (item.name, rt_marker.args[0], dev_count)
return [(nodeid, label, rt, dev) for nodeid, (label, rt, dev) in by_nodeid.items()]


def _base_pytest_argv(session):
"""Inherit the user's original pytest invocation args."""
base = [sys.executable, "-m", "pytest"]
Expand All @@ -401,8 +425,8 @@ def _resolve_max_parallel(cfg, platform: str, device_ids: list[int]) -> int:
return val


def _dispatch_test_phases(session):
"""Run L3 phase (device-parallel) then L2 phase (per-runtime subprocess)."""
def _dispatch_test_phases(session): # noqa: PLR0912
"""Run L3 → Standalone → L2 phases."""
from simpler_setup import parallel_scheduler as _ps # noqa: PLC0415

cfg = session.config
Expand Down Expand Up @@ -436,20 +460,11 @@ def _build(ids, _nodeid=nodeid, _rt=rt):
_ps.format_device_range(ids),
]

# PTO_TARGET_NODEID makes the child skip every item except this
# nodeid — defends against inherited positional args (``examples``,
# ``tests/st``) collecting unrelated classes whose fixtures would
# then fire at setup and fail on the narrower child device pool.
# SIMPLER_PERF_OUTPUT_DIR scopes this L3 case's perf files to its own
# subdir so concurrent L3 cases can't collide on filename (the
# runtime's timestamp is second-precision). Anchor to cfg.rootpath
# so the C++ runtime and Python post-processing agree regardless
# of the child's CWD. Use a nodeid-derived sanitized label so the
# dir name stays readable for post-mortem.
# subdir so concurrent L3 cases can't collide on filename.
safe_nodeid = nodeid.replace("/", "_").replace(":", "_").replace(".", "_")
child_env = {
**os.environ,
"PTO_TARGET_NODEID": nodeid,
"SIMPLER_PERF_OUTPUT_DIR": str(cfg.rootpath / "outputs" / f"perf_l3_{safe_nodeid}"),
}
jobs.append(_ps.Job(label=label, device_count=dev_count, build_cmd=_build, cwd=str(cwd), env=child_env))
Expand Down Expand Up @@ -527,15 +542,69 @@ def _on_done(res):
else:
print(f"\n--- L2 runtime {rt}: PASSED ---\n", flush=True)

# ----- Phase 3: Resource (non-SceneTestCase functions with device_count) -----
resource_cases = _collect_resource_cases(session.items, platform)
resource_failed = False
if resource_cases:
jobs = []
for nodeid, func_name, rt, dev_count in resource_cases:
label = f"resource {func_name} (rt={rt}, dev={dev_count})"

def _build(ids, _nodeid=nodeid, _rt=rt):
return base_args + [
_nodeid,
"--runtime",
_rt,
"--device",
_ps.format_device_range(ids),
]

safe_nodeid = nodeid.replace("/", "_").replace(":", "_").replace(".", "_")
child_env = {
**os.environ,
"SIMPLER_PERF_OUTPUT_DIR": str(cfg.rootpath / "outputs" / f"perf_rc_{safe_nodeid}"),
}
jobs.append(_ps.Job(label=label, device_count=dev_count, build_cmd=_build, cwd=str(cwd), env=child_env))

def _on_rc_done(res):
tag = "PASSED" if res.returncode == 0 else f"FAILED (rc={res.returncode})"
print(f"\n--- {res.label}: {tag} on devices {res.device_ids} ---\n", flush=True)

print(
f"\n{'=' * 60}\n Resource phase: {len(jobs)} case(s), "
f"pool={device_ids}, max_parallel={max_parallel}\n{'=' * 60}\n",
flush=True,
)
try:
results = _ps.run_jobs(
jobs,
device_ids,
max_parallel=max_parallel,
fail_fast=fail_fast,
on_job_done=_on_rc_done,
)
except ValueError as e:
print(f"\n*** Resource phase ABORTED: {e} ***\n", flush=True)
session.testsfailed = 1
return True
resource_failed = any(r.returncode != 0 for r in results)
if any(r.returncode == TIMEOUT_EXIT_CODE for r in results):
print("\n*** Resource phase: TIMED OUT ***\n", flush=True)
os._exit(TIMEOUT_EXIT_CODE)

if resource_failed and fail_fast:
session.testsfailed = 1
return True

# Flatten per-subprocess outputs/perf_*/ subdirs back to outputs/ so
# downstream tools (swimlane_converter.py, CI artifact upload) find
# everything in the historical location. Anchor to config.rootpath (not
# invocation_params.dir) so a user running pytest from a subdirectory
# still flushes files into the project's top-level outputs/.
_ps.flatten_perf_subdirs(cfg.rootpath / "outputs")

session.testsfailed = 1 if (l3_failed or l2_failed) else 0
if not (l3_failed or l2_failed):
session.testsfailed = 1 if (l3_failed or l2_failed or resource_failed) else 0
if not (l3_failed or l2_failed or resource_failed):
session.testscollected = sum(1 for _ in session.items)
return True # returning True prevents default runtestloop

Expand All @@ -549,8 +618,10 @@ def pytest_runtestloop(session):
runtime_filter = session.config.getoption("--runtime")
level_filter = session.config.getoption("--level")

# Child mode: the dispatcher's spawned subprocesses carry both flags.
if runtime_filter is not None and level_filter is not None:
# Child mode: if the caller filters by runtime or level, it wants direct
# control — don't re-enter the multi-phase dispatcher (which would cause
# nested dispatch, device pool exhaustion, and timeout).
if runtime_filter is not None or level_filter is not None:
return

# User explicitly asked for collect-only / scoped-run — don't orchestrate.
Expand Down
15 changes: 10 additions & 5 deletions examples/workers/l2/hello_worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()


def main() -> int:
args = parse_args()
def run(platform: str, device_id: int) -> int:
"""Core logic — callable from both CLI and pytest."""

# Worker(level=2, ...) wraps a single C++ ChipWorker. Construction does NOT
# load any binaries or touch the device — it just stashes config. The heavy
# work happens in init().
worker = Worker(
level=2,
platform=args.platform,
platform=platform,
runtime="tensormap_and_ringbuffer",
device_id=args.device,
device_id=device_id,
)

# init() resolves ``build/lib/<platform>/tensormap_and_ringbuffer/*`` via
# RuntimeBuilder, dlopens host_runtime.so, loads aicpu.so + aicore.o, and
# calls aclrtSetDevice(device_id). If any of those fails this raises.
print(f"[hello_worker] init on {args.platform} device={args.device} ...")
print(f"[hello_worker] init on {platform} device={device_id} ...")
worker.init()

try:
Expand All @@ -90,5 +90,10 @@ def main() -> int:
return 0


def main() -> int:
    """CLI entry point: parse command-line arguments and delegate to ``run``."""
    parsed = parse_args()
    return run(parsed.platform, parsed.device)


if __name__ == "__main__":
sys.exit(main())
24 changes: 24 additions & 0 deletions examples/workers/l2/hello_worker/test_hello_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Hardware ST for examples/workers/l2/hello_worker."""

import importlib.util
import os

import pytest

# Load the example's main.py by path (it is not an importable package module).
# Use the spec-based API: SourceFileLoader.load_module() is deprecated since
# Python 3.4 and slated for removal; spec_from_file_location/exec_module is
# the documented replacement with identical effect here.
_spec = importlib.util.spec_from_file_location(
    "hello_worker_main", os.path.join(os.path.dirname(__file__), "main.py")
)
_main = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_main)
run = _main.run


@pytest.mark.platforms(["a2a3sim", "a2a3", "a5sim", "a5"])
@pytest.mark.runtime("tensormap_and_ringbuffer")
@pytest.mark.device_count(1)
def test_hello_worker(st_platform, st_device_ids):
    """Run the hello_worker example end-to-end on one device; expect exit code 0.

    The Resource dispatch phase (conftest._collect_resource_cases) only picks
    up non-SceneTestCase functions that carry BOTH @pytest.mark.runtime and
    @pytest.mark.device_count — without device_count this test was silently
    excluded from dispatch.
    """
    rc = run(st_platform, int(st_device_ids[0]))
    assert rc == 0
Loading