Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions checkpoint_engine/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,49 @@ def get_ip() -> str:
return socket.gethostbyname(socket.gethostname())


def _get_npu_visible_physical_ids() -> list[int]:
visible_devices = os.getenv("ASCEND_RT_VISIBLE_DEVICES")
if not visible_devices:
return []
npu_ids = []
for device_id in visible_devices.split(","):
device_id = device_id.strip()
if device_id.isdigit():
npu_ids.append(int(device_id))
return npu_ids


def _get_npu_ids_to_scan() -> range | list[int]:
visible_physical_ids = _get_npu_visible_physical_ids()
if visible_physical_ids:
return visible_physical_ids

npu = getattr(torch, "npu", None)
device_count = getattr(npu, "device_count", None)
if callable(device_count):
count = int(device_count())
if count > 0:
return range(max(8, count))
return range(8)


def npu_generate_uuid() -> str:
str_pid = str(os.getpid())
npu_num = 8
try:
for npu_id in range(npu_num):
for npu_id in _get_npu_ids_to_scan():
cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603
str_result = str(result.stdout)
if str_pid in str_result:
# In A3 server, one NPU has two chips.
match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
if match_chip_count is None:
raise ValueError(f"Failed to parse NPU chip count for npu_id {npu_id}")
chip_count = int(match_chip_count.group(1))
search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
if match_chip_id is None:
raise ValueError(f"Failed to parse NPU chip id for npu_id {npu_id}")
chip_id = int(match_chip_id.group(1))
return f"{get_ip()}-{npu_id * chip_count + chip_id}"
raise ValueError("The current process is not running on the npu device")
Expand Down
4 changes: 2 additions & 2 deletions checkpoint_engine/distributed/vllm_hccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class HcclCommConfig(ctypes.Structure):
("hccl_op_expansion_mode", ctypes.c_uint32),
("hccl_rdma_traffic_class", ctypes.c_uint32),
("hccl_rdma_service_level", ctypes.c_uint32),
("hcll_world_rank_id", ctypes.c_uint32),
("hccl_world_rank_id", ctypes.c_uint32),
("hccl_job_id", ctypes.c_uint64),
("comm_engine", ctypes.c_int32),
("thread_num", ctypes.c_uint32),
Expand Down Expand Up @@ -167,7 +167,7 @@ def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
hccl_deterministic=0xFFFFFFFF,
hccl_comm_name=b"\0",
hccl_udi=b"\0",
hccl_op_expansize_mode=0,
hccl_op_expansion_mode=0,
hccl_rdma_traffic_class=0xFFFFFFFF,
hccl_rdma_service_level=0xFFFFFFFF,
hccl_world_rank_id=0,
Expand Down
14 changes: 12 additions & 2 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@
from checkpoint_engine.data_types import T


_CUDA_HOST_REGISTER_DEFAULT = 0x00
_CUDA_HOST_REGISTER_MAPPED = 0x02
_MANUAL_PIN_MEMORY_FLAGS = (_CUDA_HOST_REGISTER_DEFAULT, _CUDA_HOST_REGISTER_MAPPED)


def _is_valid_manual_pin_memory_flag(flag: int) -> bool:
return flag in _MANUAL_PIN_MEMORY_FLAGS


def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
ret = []
for meta in metas:
Expand Down Expand Up @@ -401,8 +410,9 @@ def _unpin(t: torch.Tensor):
# cudaHostRegisterMapped 0x02 /**< Map registered memory into device space */
# cudaHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */
# cudaHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */
assert p_flags.value == 0x02, (
f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
assert _is_valid_manual_pin_memory_flag(p_flags.value), (
"pin memory flag error, expected: 0x00 (cudaHostRegisterDefault) "
f"or 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
)
cudart = torch.cuda.cudart()
r = cudart.cudaHostUnregister(t.data_ptr())
Expand Down
7 changes: 7 additions & 0 deletions tests/test_cuda_pin_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from checkpoint_engine.ps import _is_valid_manual_pin_memory_flag


def test_manual_pin_memory_flags_accept_default_and_mapped() -> None:
assert _is_valid_manual_pin_memory_flag(0x00)
assert _is_valid_manual_pin_memory_flag(0x02)
assert not _is_valid_manual_pin_memory_flag(0x01)
77 changes: 77 additions & 0 deletions tests/test_hccl_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import ctypes
import importlib
import sys
import types
from typing import ClassVar

import pytest
import torch


def test_hccl_comm_config_field_names_are_assignable(monkeypatch: pytest.MonkeyPatch) -> None:
class FakeFunction:
def __init__(self, *args: object) -> None:
self.args = args

class FakeHcclLibrary:
exported_functions: ClassVar[list[object]] = []

class FakePyHcclCommunicator:
def __init__(self, *args: object, **kwargs: object) -> None:
pass

class FakeHcclDataTypeEnum:
@staticmethod
def from_torch(dtype: torch.dtype) -> int:
return 0

class FakeNpu:
class Stream:
npu_stream = 0

modules = {
"vllm": types.ModuleType("vllm"),
"vllm.distributed": types.ModuleType("vllm.distributed"),
"vllm.distributed.utils": types.ModuleType("vllm.distributed.utils"),
"vllm_ascend": types.ModuleType("vllm_ascend"),
"vllm_ascend.distributed": types.ModuleType("vllm_ascend.distributed"),
"vllm_ascend.distributed.device_communicators": types.ModuleType(
"vllm_ascend.distributed.device_communicators"
),
"vllm_ascend.distributed.device_communicators.pyhccl": types.ModuleType(
"vllm_ascend.distributed.device_communicators.pyhccl"
),
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper": types.ModuleType(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper"
),
"vllm_ascend.utils": types.ModuleType("vllm_ascend.utils"),
}
modules["vllm.distributed.utils"].StatelessProcessGroup = object
modules[
"vllm_ascend.distributed.device_communicators.pyhccl"
].PyHcclCommunicator = FakePyHcclCommunicator
wrapper = modules["vllm_ascend.distributed.device_communicators.pyhccl_wrapper"]
wrapper.Function = FakeFunction
wrapper.HCCLLibrary = FakeHcclLibrary
wrapper.aclrtStream_t = ctypes.c_void_p
wrapper.buffer_type = ctypes.c_void_p
wrapper.hcclComm_t = ctypes.c_void_p
wrapper.hcclDataType_t = ctypes.c_int
wrapper.hcclDataTypeEnum = FakeHcclDataTypeEnum
wrapper.hcclResult_t = int
modules["vllm_ascend.utils"].current_stream = lambda: FakeNpu.Stream()

for name, module in modules.items():
monkeypatch.setitem(sys.modules, name, module)
monkeypatch.setattr(torch, "npu", FakeNpu(), raising=False)
monkeypatch.delitem(sys.modules, "checkpoint_engine.distributed.vllm_hccl", raising=False)

module = importlib.import_module("checkpoint_engine.distributed.vllm_hccl")
field_names = [name for name, _ in module.HcclCommConfig._fields_]

assert "hccl_op_expansion_mode" in field_names
assert "hccl_world_rank_id" in field_names
assert "hcll_world_rank_id" not in field_names
config = module.HcclCommConfig(hccl_op_expansion_mode=7, hccl_world_rank_id=3)
assert config.hccl_op_expansion_mode == 7
assert config.hccl_world_rank_id == 3
72 changes: 72 additions & 0 deletions tests/test_npu_device_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import subprocess

import pytest
import torch

from checkpoint_engine.device_utils import npu_generate_uuid


def test_npu_generate_uuid_checks_reported_device_count(monkeypatch: pytest.MonkeyPatch) -> None:
class FakeNpu:
@staticmethod
def device_count() -> int:
return 10

pid = 12345
seen_npu_ids: list[int] = []

def fake_run(
cmd: list[str],
*,
check: bool,
capture_output: bool,
text: bool,
) -> subprocess.CompletedProcess[str]:
npu_id = int(cmd[-1])
seen_npu_ids.append(npu_id)
stdout = "Chip Count: 2\n"
if npu_id == 9:
stdout += f"PID: {pid}\nChip ID: 1\n"
return subprocess.CompletedProcess(cmd, 0, stdout=stdout, stderr="")

monkeypatch.setattr(torch, "npu", FakeNpu(), raising=False)
monkeypatch.delenv("ASCEND_RT_VISIBLE_DEVICES", raising=False)
monkeypatch.setattr("checkpoint_engine.device_utils.os.getpid", lambda: pid)
monkeypatch.setattr("checkpoint_engine.device_utils.get_ip", lambda: "10.0.0.1")
monkeypatch.setattr("checkpoint_engine.device_utils.subprocess.run", fake_run)

assert npu_generate_uuid() == "10.0.0.1-19"
assert seen_npu_ids == list(range(10))


def test_npu_generate_uuid_scans_visible_physical_ids(monkeypatch: pytest.MonkeyPatch) -> None:
class FakeNpu:
@staticmethod
def device_count() -> int:
return 1

pid = 12345
seen_npu_ids: list[int] = []

def fake_run(
cmd: list[str],
*,
check: bool,
capture_output: bool,
text: bool,
) -> subprocess.CompletedProcess[str]:
npu_id = int(cmd[-1])
seen_npu_ids.append(npu_id)
stdout = "Chip Count: 2\n"
if npu_id == 4:
stdout += f"PID: {pid}\nChip ID: 1\n"
return subprocess.CompletedProcess(cmd, 0, stdout=stdout, stderr="")

monkeypatch.setattr(torch, "npu", FakeNpu(), raising=False)
monkeypatch.setenv("ASCEND_RT_VISIBLE_DEVICES", "4")
monkeypatch.setattr("checkpoint_engine.device_utils.os.getpid", lambda: pid)
monkeypatch.setattr("checkpoint_engine.device_utils.get_ip", lambda: "10.0.0.1")
monkeypatch.setattr("checkpoint_engine.device_utils.subprocess.run", fake_run)

assert npu_generate_uuid() == "10.0.0.1-9"
assert seen_npu_ids == [4]
Loading