20 changes: 19 additions & 1 deletion accelerator/real_accelerator.py
@@ -20,7 +20,7 @@
except ImportError as e:
dsa2 = None

SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa']
SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa', 'xla']

ds_accelerator = None

@@ -98,6 +98,12 @@ def get_accelerator():
import torch_mlu # noqa: F401
except ImportError as e:
raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.")
elif accelerator_name in ["xla", "tpu"]:
accelerator_name = "xla"
try:
import torch_xla # noqa: F401
except ImportError as e:
raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.")
elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
f'Value "{accelerator_name}" is not supported')
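
For reference, a quick usage sketch of the override path handled above ("tpu" is accepted as an alias and normalized to "xla"); the import is DeepSpeed's public accelerator entry point:

    import os
    os.environ['DS_ACCELERATOR'] = 'tpu'   # normalized to 'xla' by the branch above

    from deepspeed.accelerator import get_accelerator
    get_accelerator()   # raises ValueError if torch_xla is not installed
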
@@ -125,6 +131,14 @@ def get_accelerator():
accelerator_name = "xpu"
except ImportError as e:
pass
if accelerator_name is None:
try:
import torch_xla.core.xla_model as xm

if len(xm.get_xla_supported_devices(devkind='TPU')) > 0:
accelerator_name = "xla"
except ImportError as e:
pass
if accelerator_name is None:
try:
import torch_npu # noqa: F401,F811 # type: ignore
@@ -220,6 +234,10 @@ def get_accelerator():
from .mlu_accelerator import MLU_Accelerator

ds_accelerator = MLU_Accelerator()
elif accelerator_name == 'xla':
from .xla_accelerator import XLA_Accelerator

ds_accelerator = XLA_Accelerator()
_validate_accelerator(ds_accelerator)
if accel_logger is not None:
accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
265 changes: 265 additions & 0 deletions accelerator/xla_accelerator.py
@@ -0,0 +1,265 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import functools

import torch

from .abstract_accelerator import DeepSpeedAccelerator

try:
import torch_xla.core.xla_model as xm
except ImportError as e:
xm = None


class XLA_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'xla'
self._communication_backend_name = 'xla'
self._compile_backend = None
if xm is None:
raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.")

def _require_xm(self):
if xm is None:
raise RuntimeError("torch_xla is required to use the XLA_Accelerator")
return xm

def _tensor_factory(self, dtype):
return functools.partial(torch.tensor, dtype=dtype, device=self.current_device_name())

def is_synchronized_device(self):
return True

def use_host_timers(self):
return True

def resolves_data_dependency(self):
return True

def handles_memory_backpressure(self):
return True

# Device APIs
def device_name(self, device_index=None):
if device_index is None:
return 'xla'
return f'xla:{device_index}'
P1: Map XLA device names to addressable device indices

device_name(device_index) is used by DeepSpeedEngine._set_distributed_vars() (deepspeed/runtime/engine.py:1287-1290) to build self.device from LOCAL_RANK, and the module is then moved there at line 1427. Under PJRT, LOCAL_RANK is not the per-process XLA device index: each worker only sees its own addressable XLA devices, often just xla:0. Returning xla:{LOCAL_RANK} here means that nonzero ranks on a multi-chip TPU will try to place the model on xla:1, xla:2, etc., which are not addressable in that process, causing multi-rank launches to fail or target the wrong chip.


Collaborator Author:

Addressed in 7f82c20. device_name(device_index) now resolves through the process addressable XLA device list instead of treating LOCAL_RANK as a raw XLA ordinal, so single-device-per-process workers map nonzero local ranks back to their local xla:0 device correctly.
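
A minimal sketch of the described fix (assumed shape; the actual change is in commit 7f82c20 and is not part of this diff), resolving the index against the devices addressable by this process:

    def device_name(self, device_index=None):
        if device_index is None:
            return 'xla'
        xm_module = self._require_xm()
        # Devices visible to *this* process; under PJRT this is often just
        # ['xla:0'], regardless of the launcher-provided LOCAL_RANK.
        local_devices = xm_module.get_xla_supported_devices(devkind='TPU')
        if local_devices:
            # Map the local rank into the addressable range instead of
            # treating it as a raw XLA ordinal.
            return local_devices[device_index % len(local_devices)]
        return f'xla:{device_index}'
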


def device(self, device_index=None):
xm_module = self._require_xm()
return xm_module.xla_device(n=device_index, devkind='TPU')

def set_device(self, device_index):
os.environ['LOCAL_RANK'] = str(device_index)
Comment on lines +81 to +84

P1: Actually bind each TPU worker in set_device()

DeepSpeed's launcher gives every local process the same TPU visibility mask (deepspeed/launcher/launch.py:182-183) and relies on get_accelerator().set_device(local_rank) from DeepSpeedEngine._set_distributed_vars() to pin each worker to its chip. This implementation only rewrites LOCAL_RANK; it never calls a torch_xla/PJRT device-selection API or sets the PJRT process-rank env that torch_xla uses to derive local ordinals. On a host with multiple TPU chips, multiple ranks can therefore attach to the same default XLA device, which breaks distributed initialization and ZeRO synchronization.


Collaborator Author:

Addressed in 7f82c20. set_device() now calls into xm.xla_device() to select the XLA default device for the current process before DeepSpeed moves the model, and it preserves the launcher-provided rank information in the environment.
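
A minimal sketch of that behavior (assumed shape; the real change is in 7f82c20):

    def set_device(self, device_index):
        # Preserve the launcher-provided rank for later ordinal queries.
        os.environ['LOCAL_RANK'] = str(device_index)
        xm_module = self._require_xm()
        # Materialize the default XLA device so this worker is bound to its
        # chip before DeepSpeed moves the model there.
        xm_module.xla_device()
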


def current_device(self):
xm_module = self._require_xm()
return xm_module.get_local_ordinal()

def current_device_name(self):
return self.device_name(self.current_device())

def device_count(self):
xm_module = self._require_xm()
return len(xm_module.get_xla_supported_devices(devkind='TPU'))

def synchronize(self, device_index=None):
xm_module = self._require_xm()
xm_module.mark_step()
return xm_module.wait_device_ops()

# RNG APIs
def random(self):
return torch.random

def set_rng_state(self, new_state, device_index=None):
return torch.set_rng_state(new_state)

def get_rng_state(self, device_index=None):
return torch.get_rng_state()

def manual_seed(self, seed):
return torch.manual_seed(seed)

def manual_seed_all(self, seed):
return torch.manual_seed(seed)

def initial_seed(self):
return torch.initial_seed()

def default_generator(self, device_index):
return torch.default_generator

# Streams/Events
@property
def Stream(self):
return None

def stream(self, stream):
from deepspeed.runtime.utils import noop_context
return noop_context()

def current_stream(self, device_index=None):
return None

def default_stream(self, device_index=None):
return None

@property
def Event(self):
return None

# Memory management
def empty_cache(self):
return

def memory_allocated(self, device_index=None):
return 0

def max_memory_allocated(self, device_index=None):
return 0

def reset_max_memory_allocated(self, device_index=None):
return

def memory_cached(self, device_index=None):
return 0

def max_memory_cached(self, device_index=None):
return 0

def reset_max_memory_cached(self, device_index=None):
return

def memory_stats(self, device_index=None):
return {}

def reset_peak_memory_stats(self, device_index=None):
return

def memory_reserved(self, device_index=None):
return 0

def max_memory_reserved(self, device_index=None):
return 0

def total_memory(self, device_index=None):
return 0

def available_memory(self, device_index=None):
return 0

# Data types
def is_bf16_supported(self):
return True

def is_fp16_supported(self):
return False

def supported_dtypes(self):
return [torch.float32, torch.bfloat16]

# Misc
def is_available(self):
return self.device_count() > 0

def range_push(self, msg):
return

def range_pop(self):
return

def lazy_call(self, callback):
return callback()

def communication_backend_name(self):
return self._communication_backend_name

def is_triton_supported(self):
return False

# Graph operations
def create_graph(self):
return None

def capture_to_graph(self, graph, pool=None, stream=None):
from deepspeed.runtime.utils import noop_context
return noop_context()

def replay_graph(self, graph):
return

# Tensor operations
@property
def BFloat16Tensor(self):
return self._tensor_factory(torch.bfloat16)

@property
def ByteTensor(self):
return self._tensor_factory(torch.uint8)

@property
def DoubleTensor(self):
return self._tensor_factory(torch.float64)

@property
def FloatTensor(self):
return self._tensor_factory(torch.float32)

@property
def HalfTensor(self):
return self._tensor_factory(torch.float16)

@property
def IntTensor(self):
return self._tensor_factory(torch.int32)

@property
def LongTensor(self):
return self._tensor_factory(torch.int64)

def pin_memory(self, tensor, align_bytes=1):
return tensor

def is_pinned(self, tensor):
return False

def on_accelerator(self, tensor):
return getattr(tensor.device, 'type', None) == 'xla'

def op_builder_dir(self):
return "deepspeed.ops.op_builder.cpu"

def create_op_builder(self, op_name):
return None

def get_op_builder(self, class_name):
return None

def build_extension(self):
from torch.utils.cpp_extension import BuildExtension
return BuildExtension

def export_envs(self):
return ['PJRT_DEVICE', 'TPU_VISIBLE_CHIPS']

def visible_devices_envs(self):
return ['TPU_VISIBLE_CHIPS']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
for env in self.visible_devices_envs():
current_env[env] = ",".join(map(str, local_accelerator_ids))
current_env.setdefault('PJRT_DEVICE', 'TPU')

def get_compile_backend(self):
return self._compile_backend

def set_compile_backend(self, backend):
if backend is not None:
raise ValueError(f"{backend} not supported by {self.device_name()}. Supported Backends are [None]")
7 changes: 6 additions & 1 deletion deepspeed/comm/comm.py
@@ -162,6 +162,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method):
utils.logger.info(f"Initialize {ds_backend} backend")
elif ds_backend == HCCL_BACKEND:
utils.logger.debug("HCCL backend in DeepSpeed not yet implemented")
elif ds_backend == XLA_BACKEND:
utils.logger.debug("XLA backend in DeepSpeed is provided via torch.distributed")
else:
utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend")

@@ -821,6 +823,8 @@ def init_distributed(dist_backend=None,
utils.logger.info(f'cdb={cdb}')
if cdb is None and torch.distributed.is_initialized():
# The user initialized torch.dist themselves, create cdb and short-circuit
if dist_backend is None:
Collaborator:

Why do we need this behavior? Is it specific to xla?

Collaborator Author:

@sfc-gh-truwase Thanks for the call out! I added a comment clarifying that it's a general fix (not XLA-specific) — it prevents passing None to TorchBackend when the user pre-initialized torch.distributed without specifying dist_backend.
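
For reference, a hedged sketch of the scenario the guard covers (the backend choice is illustrative):

    import torch.distributed as dist
    import deepspeed

    dist.init_process_group(backend='nccl')   # user pre-initializes torch.distributed
    # dist_backend defaults to None here; without the fallback, TorchBackend
    # would be constructed with backend=None instead of the accelerator's
    # communication backend name.
    deepspeed.init_distributed()
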

dist_backend = get_accelerator().communication_backend_name()
cdb = TorchBackend(dist_backend, timeout, init_method)
return

@@ -831,7 +835,8 @@ def init_distributed(dist_backend=None,
else:
# Initialize torch distributed if needed
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
xla_backend = dist_backend == XLA_BACKEND or get_accelerator().communication_backend_name() == XLA_BACKEND
if auto_mpi_discovery and not xla_backend and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
if in_aml() and not in_dlts():
1 change: 1 addition & 0 deletions deepspeed/comm/constants.py
Expand Up @@ -9,6 +9,7 @@
GLOO_BACKEND = 'gloo'
SCCL_BACKEND = 'sccl'
HCCL_BACKEND = 'hccl'
XLA_BACKEND = 'xla'

DEFAULT_AML_MASTER_PORT = "54965"
DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo"
11 changes: 11 additions & 0 deletions deepspeed/comm/torch.py
@@ -147,6 +147,10 @@ def has_reduce_scatter_tensor(self):

def init_process_group(self, backend, timeout, init_method, rank, world_size):
if not torch.distributed.is_initialized():
if backend == XLA_BACKEND:
import torch_xla.distributed.xla_backend # noqa: F401
if init_method is None:
init_method = "xla://"
kwargs = dict(timeout=timeout, init_method=init_method, rank=rank, world_size=world_size)

# 1. device_id arg was added in torch==2.3
@@ -159,6 +163,13 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size):
kwargs.update(device_id=get_accelerator().device(local_rank))
torch.distributed.init_process_group(backend, **kwargs)

if backend == XLA_BACKEND:
os.environ.setdefault('RANK', str(torch.distributed.get_rank()))
os.environ.setdefault('WORLD_SIZE', str(torch.distributed.get_world_size()))
if 'LOCAL_RANK' not in os.environ:
import torch_xla.core.xla_model as xm
os.environ['LOCAL_RANK'] = str(xm.get_local_ordinal())

self.using_mpi = torch.distributed.get_backend() == 'mpi'

@disable_compiler_collective
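For context, the torch_xla wiring this relies on can be exercised standalone (a sketch, assuming a PJRT TPU environment):

    import torch.distributed as dist
    import torch_xla.distributed.xla_backend  # noqa: F401  (registers the 'xla' backend)

    # 'xla://' lets PJRT supply rank, world size, and rendezvous, which is why
    # RANK/WORLD_SIZE/LOCAL_RANK are back-filled into the environment afterwards.
    dist.init_process_group(backend='xla', init_method='xla://')
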
11 changes: 4 additions & 7 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -59,14 +59,11 @@ def input(msg):


def split_half_float_double(tensors):
device_type = get_accelerator().device_name()
dtypes = [
"torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
"torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type)
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
accelerator = get_accelerator()
dtype_order = (torch.float16, torch.float32, torch.float64, torch.bfloat16)
for dtype in dtype_order:
bucket = [tensor for tensor in tensors if tensor.dtype == dtype and accelerator.on_accelerator(tensor)]
Collaborator:

Why does xla require special handling given the prior code worked for other accelerators?

Collaborator:

Another issue is that it seems on_accelerator is only defined for xla and so other accelerators will break here.

Collaborator Author:

@sfc-gh-truwase Thanks for the comments:) I removed the accelerator.on_accelerator(tensor) filter that would have changed behavior for all backends. I also kept the dtype-based comparison (replacing the old string-based type names like torch.cuda.HalfTensor) since that's the actual fix needed for XLA compatibility without breaking other accelerators.
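
A small illustration of the difference (assuming torch_xla is installed; the exact type() string is backend- and version-dependent):

    import torch
    import torch_xla.core.xla_model as xm

    t = torch.zeros(4, dtype=torch.float16, device=xm.xla_device())
    t.type()    # backend-specific string, not guaranteed to match
                # 'torch.xla.HalfTensor', so string bucketing misses it
    t.dtype     # torch.float16 on every backend, so dtype bucketing is robust
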

if bucket:
buckets.append(bucket)
return buckets