diff --git a/checkpoint_engine/device_utils.py b/checkpoint_engine/device_utils.py index be5b81f..3173f2e 100644 --- a/checkpoint_engine/device_utils.py +++ b/checkpoint_engine/device_utils.py @@ -25,20 +25,49 @@ def get_ip() -> str: return socket.gethostbyname(socket.gethostname()) +def _get_npu_visible_physical_ids() -> list[int]: + visible_devices = os.getenv("ASCEND_RT_VISIBLE_DEVICES") + if not visible_devices: + return [] + npu_ids = [] + for device_id in visible_devices.split(","): + device_id = device_id.strip() + if device_id.isdigit(): + npu_ids.append(int(device_id)) + return npu_ids + + +def _get_npu_ids_to_scan() -> range | list[int]: + visible_physical_ids = _get_npu_visible_physical_ids() + if visible_physical_ids: + return visible_physical_ids + + npu = getattr(torch, "npu", None) + device_count = getattr(npu, "device_count", None) + if callable(device_count): + count = int(device_count()) + if count > 0: + return range(max(8, count)) + return range(8) + + def npu_generate_uuid() -> str: str_pid = str(os.getpid()) - npu_num = 8 try: - for npu_id in range(npu_num): + for npu_id in _get_npu_ids_to_scan(): cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)] result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603 str_result = str(result.stdout) if str_pid in str_result: # In A3 server, one NPU has two chips. match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result) + if match_chip_count is None: + raise ValueError(f"Failed to parse NPU chip count for npu_id {npu_id}") chip_count = int(match_chip_count.group(1)) search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :] match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid) + if match_chip_id is None: + raise ValueError(f"Failed to parse NPU chip id for npu_id {npu_id}") chip_id = int(match_chip_id.group(1)) return f"{get_ip()}-{npu_id * chip_count + chip_id}" raise ValueError("The current process is not running on the npu device") diff --git a/checkpoint_engine/distributed/vllm_hccl.py b/checkpoint_engine/distributed/vllm_hccl.py index fbdab0c..736e769 100644 --- a/checkpoint_engine/distributed/vllm_hccl.py +++ b/checkpoint_engine/distributed/vllm_hccl.py @@ -34,7 +34,7 @@ class HcclCommConfig(ctypes.Structure): ("hccl_op_expansion_mode", ctypes.c_uint32), ("hccl_rdma_traffic_class", ctypes.c_uint32), ("hccl_rdma_service_level", ctypes.c_uint32), - ("hcll_world_rank_id", ctypes.c_uint32), + ("hccl_world_rank_id", ctypes.c_uint32), ("hccl_job_id", ctypes.c_uint64), ("comm_engine", ctypes.c_int32), ("thread_num", ctypes.c_uint32), @@ -167,7 +167,7 @@ def create_subcomm(self, ranks: list[int]) -> hcclComm_t: hccl_deterministic=0xFFFFFFFF, hccl_comm_name=b"\0", hccl_udi=b"\0", - hccl_op_expansize_mode=0, + hccl_op_expansion_mode=0, hccl_rdma_traffic_class=0xFFFFFFFF, hccl_rdma_service_level=0xFFFFFFFF, hccl_world_rank_id=0, diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 1d8c5cf..d412fba 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -32,6 +32,15 @@ from checkpoint_engine.data_types import T +_CUDA_HOST_REGISTER_DEFAULT = 0x00 +_CUDA_HOST_REGISTER_MAPPED = 0x02 +_MANUAL_PIN_MEMORY_FLAGS = (_CUDA_HOST_REGISTER_DEFAULT, _CUDA_HOST_REGISTER_MAPPED) + + +def _is_valid_manual_pin_memory_flag(flag: int) -> bool: + return flag in _MANUAL_PIN_MEMORY_FLAGS + + def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]: ret = [] for meta in metas: @@ -401,8 +410,9 @@ def _unpin(t: torch.Tensor): # cudaHostRegisterMapped 0x02 /**< Map registered memory into device space */ # cudaHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */ # cudaHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */ - assert p_flags.value == 0x02, ( - f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}" + assert _is_valid_manual_pin_memory_flag(p_flags.value), ( + "pin memory flag error, expected: 0x00 (cudaHostRegisterDefault) " + f"or 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}" ) cudart = torch.cuda.cudart() r = cudart.cudaHostUnregister(t.data_ptr()) diff --git a/tests/test_cuda_pin_memory.py b/tests/test_cuda_pin_memory.py new file mode 100644 index 0000000..97a21c1 --- /dev/null +++ b/tests/test_cuda_pin_memory.py @@ -0,0 +1,7 @@ +from checkpoint_engine.ps import _is_valid_manual_pin_memory_flag + + +def test_manual_pin_memory_flags_accept_default_and_mapped() -> None: + assert _is_valid_manual_pin_memory_flag(0x00) + assert _is_valid_manual_pin_memory_flag(0x02) + assert not _is_valid_manual_pin_memory_flag(0x01) diff --git a/tests/test_hccl_config.py b/tests/test_hccl_config.py new file mode 100644 index 0000000..fc145b1 --- /dev/null +++ b/tests/test_hccl_config.py @@ -0,0 +1,77 @@ +import ctypes +import importlib +import sys +import types +from typing import ClassVar + +import pytest +import torch + + +def test_hccl_comm_config_field_names_are_assignable(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeFunction: + def __init__(self, *args: object) -> None: + self.args = args + + class FakeHcclLibrary: + exported_functions: ClassVar[list[object]] = [] + + class FakePyHcclCommunicator: + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + class FakeHcclDataTypeEnum: + @staticmethod + def from_torch(dtype: torch.dtype) -> int: + return 0 + + class FakeNpu: + class Stream: + npu_stream = 0 + + modules = { + "vllm": types.ModuleType("vllm"), + "vllm.distributed": types.ModuleType("vllm.distributed"), + "vllm.distributed.utils": types.ModuleType("vllm.distributed.utils"), + "vllm_ascend": types.ModuleType("vllm_ascend"), + "vllm_ascend.distributed": types.ModuleType("vllm_ascend.distributed"), + "vllm_ascend.distributed.device_communicators": types.ModuleType( + "vllm_ascend.distributed.device_communicators" + ), + "vllm_ascend.distributed.device_communicators.pyhccl": types.ModuleType( + "vllm_ascend.distributed.device_communicators.pyhccl" + ), + "vllm_ascend.distributed.device_communicators.pyhccl_wrapper": types.ModuleType( + "vllm_ascend.distributed.device_communicators.pyhccl_wrapper" + ), + "vllm_ascend.utils": types.ModuleType("vllm_ascend.utils"), + } + modules["vllm.distributed.utils"].StatelessProcessGroup = object + modules[ + "vllm_ascend.distributed.device_communicators.pyhccl" + ].PyHcclCommunicator = FakePyHcclCommunicator + wrapper = modules["vllm_ascend.distributed.device_communicators.pyhccl_wrapper"] + wrapper.Function = FakeFunction + wrapper.HCCLLibrary = FakeHcclLibrary + wrapper.aclrtStream_t = ctypes.c_void_p + wrapper.buffer_type = ctypes.c_void_p + wrapper.hcclComm_t = ctypes.c_void_p + wrapper.hcclDataType_t = ctypes.c_int + wrapper.hcclDataTypeEnum = FakeHcclDataTypeEnum + wrapper.hcclResult_t = int + modules["vllm_ascend.utils"].current_stream = lambda: FakeNpu.Stream() + + for name, module in modules.items(): + monkeypatch.setitem(sys.modules, name, module) + monkeypatch.setattr(torch, "npu", FakeNpu(), raising=False) + monkeypatch.delitem(sys.modules, "checkpoint_engine.distributed.vllm_hccl", raising=False) + + module = importlib.import_module("checkpoint_engine.distributed.vllm_hccl") + field_names = [name for name, _ in module.HcclCommConfig._fields_] + + assert "hccl_op_expansion_mode" in field_names + assert "hccl_world_rank_id" in field_names + assert "hcll_world_rank_id" not in field_names + config = module.HcclCommConfig(hccl_op_expansion_mode=7, hccl_world_rank_id=3) + assert config.hccl_op_expansion_mode == 7 + assert config.hccl_world_rank_id == 3 diff --git a/tests/test_npu_device_utils.py b/tests/test_npu_device_utils.py new file mode 100644 index 0000000..dbf0e33 --- /dev/null +++ b/tests/test_npu_device_utils.py @@ -0,0 +1,72 @@ +import subprocess + +import pytest +import torch + +from checkpoint_engine.device_utils import npu_generate_uuid + + +def test_npu_generate_uuid_checks_reported_device_count(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeNpu: + @staticmethod + def device_count() -> int: + return 10 + + pid = 12345 + seen_npu_ids: list[int] = [] + + def fake_run( + cmd: list[str], + *, + check: bool, + capture_output: bool, + text: bool, + ) -> subprocess.CompletedProcess[str]: + npu_id = int(cmd[-1]) + seen_npu_ids.append(npu_id) + stdout = "Chip Count: 2\n" + if npu_id == 9: + stdout += f"PID: {pid}\nChip ID: 1\n" + return subprocess.CompletedProcess(cmd, 0, stdout=stdout, stderr="") + + monkeypatch.setattr(torch, "npu", FakeNpu(), raising=False) + monkeypatch.delenv("ASCEND_RT_VISIBLE_DEVICES", raising=False) + monkeypatch.setattr("checkpoint_engine.device_utils.os.getpid", lambda: pid) + monkeypatch.setattr("checkpoint_engine.device_utils.get_ip", lambda: "10.0.0.1") + monkeypatch.setattr("checkpoint_engine.device_utils.subprocess.run", fake_run) + + assert npu_generate_uuid() == "10.0.0.1-19" + assert seen_npu_ids == list(range(10)) + + +def test_npu_generate_uuid_scans_visible_physical_ids(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeNpu: + @staticmethod + def device_count() -> int: + return 1 + + pid = 12345 + seen_npu_ids: list[int] = [] + + def fake_run( + cmd: list[str], + *, + check: bool, + capture_output: bool, + text: bool, + ) -> subprocess.CompletedProcess[str]: + npu_id = int(cmd[-1]) + seen_npu_ids.append(npu_id) + stdout = "Chip Count: 2\n" + if npu_id == 4: + stdout += f"PID: {pid}\nChip ID: 1\n" + return subprocess.CompletedProcess(cmd, 0, stdout=stdout, stderr="") + + monkeypatch.setattr(torch, "npu", FakeNpu(), raising=False) + monkeypatch.setenv("ASCEND_RT_VISIBLE_DEVICES", "4") + monkeypatch.setattr("checkpoint_engine.device_utils.os.getpid", lambda: pid) + monkeypatch.setattr("checkpoint_engine.device_utils.get_ip", lambda: "10.0.0.1") + monkeypatch.setattr("checkpoint_engine.device_utils.subprocess.run", fake_run) + + assert npu_generate_uuid() == "10.0.0.1-9" + assert seen_npu_ids == [4]