Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions csrc/compile/deepcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)

if (sync_after_reduce) { c10::cuda::device_synchronize(); }

return at::Tensor();
return torch::empty({0}, grad_tensor.options());
}

at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
{
return at::Tensor();
return torch::empty({0}, grad_tensor.options());
}

void free_tensors(std::vector<at::Tensor> tensors)
Expand Down Expand Up @@ -179,10 +179,12 @@ void start_backward(bool update)
for (auto& it : executors) { it.second->startBackward(update); }
}

void end_backward(long graph_id)
void end_backward(const c10::IValue& deps, long graph_id)
{
auto executor = getExecutor<CustomOpExecutor>(graph_id, executors);
executor->endBackward();
}

void end_backward_meta(const c10::IValue& deps, long graph_id) {}

} // namespace dc
9 changes: 6 additions & 3 deletions csrc/compile/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TORCH_LIBRARY(dc, m)
m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
m.def("offload_parameter(Tensor a, int id, int id) -> ()");
m.def("reload_parameter(Tensor a, int id, int id) -> ()");
m.def("end_backward(int graph_id) -> ()");
m.def("end_backward(Any deps, int graph_id) -> ()");

m.def("test_call(Tensor a) -> Tensor");
}
Expand All @@ -43,6 +43,7 @@ TORCH_LIBRARY_IMPL(dc, CPU, m)
m.impl("wait_reload", &dc::wait_reload);
m.impl("offload_parameter", &dc::offload_parameter);
m.impl("reload_parameter", &dc::reload_parameter);
m.impl("end_backward", &dc::end_backward);

m.impl("test_call", &dc::test_call);
}
Expand All @@ -61,6 +62,7 @@ TORCH_LIBRARY_IMPL(dc, CUDA, m)
m.impl("wait_reload", &dc::wait_reload);
m.impl("offload_parameter", &dc::offload_parameter);
m.impl("reload_parameter", &dc::reload_parameter);
m.impl("end_backward", &dc::end_backward);

m.impl("test_call", &dc::test_call);
}
Expand All @@ -75,10 +77,11 @@ TORCH_LIBRARY_IMPL(dc, Meta, m)
m.impl("free_tensors", &dc::free_tensors_meta);
m.impl("reload_parameter", &dc::reload_parameter_meta);
m.impl("offload_parameter", &dc::offload_parameter_meta);
m.impl("end_backward", &dc::end_backward_meta);
}

// The "Undefined" dispatch key is for operations whose arguments do not contain
// a tensor.
// end_backward may be invoked with dependency placeholders that have already
// become None, in which case the dispatcher sees no tensor arguments.
TORCH_LIBRARY_IMPL(dc, Undefined, m) { m.impl("end_backward", &dc::end_backward); }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
Expand Down
3 changes: 2 additions & 1 deletion csrc/compile/z3.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,6 @@ void reload_parameter(at::Tensor tensor, long graph_id, long id);
void offload_parameter(at::Tensor tensor, long graph_id, long id);
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void end_backward(long graph_id);
void end_backward(const c10::IValue& deps, long graph_id);
void end_backward_meta(const c10::IValue& deps, long graph_id);
} // namespace dc
18 changes: 18 additions & 0 deletions deepspeed/compile/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from torch.fx import Node, Graph, GraphModule
from torch.fx.node import map_aggregate

from .util import get_last_uses

Expand All @@ -19,6 +20,23 @@ def get_output_node(graph: Graph):
raise ValueError("No output node found")


def add_end_backward(graph: Graph, graph_id: int):
reduce_nodes = [n for n in graph.nodes if n.target == torch.ops.dc.reduce_grad.default]
if len(reduce_nodes) == 0:
return

with graph.inserting_before(get_output_node(graph)):
graph.create_node("call_function", torch.ops.dc.end_backward.default, (reduce_nodes, graph_id))


def replace_reduce_outputs_with_none(graph: Graph):
output_node = get_output_node(graph)
new_outputs = map_aggregate(
output_node.args[0], lambda n: None
if isinstance(n, Node) and n.target == torch.ops.dc.reduce_grad.default else n)
output_node.args = (new_outputs, )


def move_primals_to_head(graph: Graph):

# Move primals to the head of the graph
Expand Down
1 change: 1 addition & 0 deletions deepspeed/compile/inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def register_fallback_no_reuse(op_overload,
never_reuse_output=True,
force_free_input=True)
register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)
register_fallback_no_reuse(torch.ops.dc.end_backward.default, never_reuse_input=True, never_reuse_output=False)

if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
Scheduler.is_dc_patched = True
Expand Down
119 changes: 92 additions & 27 deletions deepspeed/compile/passes/selective_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# DeepSpeed Team

from collections import defaultdict
from typing import List, Tuple
from typing import Dict, List, Tuple

import torch
from torch.fx import GraphModule

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import log_dist

from ..util import get_deepcompile_handle
from ..graph_param import DSGraphParamManager
Expand All @@ -19,11 +20,44 @@

max_alloc_mem = 0
last_optimize_step = 0
MEM_MARGIN = 0.1


def print_rank_0(message):
log_dist(message, ranks=[0])


def _compute_persistence_budget(all_graph_mem_records: List[List[Tuple[str, int, int, int]]], total_mem: int,
mem_margin: float) -> Dict[str, int]:
usable_mem = int(total_mem * (1 - mem_margin))
non_empty_records = [mem_records for mem_records in all_graph_mem_records if mem_records]

if not non_empty_records:
return {
"usable_mem": usable_mem,
"peak_resident_alloc": 0,
"transient_peak": 0,
"available_mem": 0,
"profiled_list_count": 0,
}

# Persistent parameters add to live allocations that remain resident past an op boundary.
peak_resident_alloc = max(record[1] for mem_records in non_empty_records for record in mem_records)
transient_peak = max(record[3] for mem_records in non_empty_records for record in mem_records)

return {
"usable_mem": usable_mem,
"peak_resident_alloc": peak_resident_alloc,
"transient_peak": transient_peak,
"available_mem": max(0, usable_mem - peak_resident_alloc),
"profiled_list_count": len(non_empty_records),
}


def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
bwd: bool) -> GraphModule:
target_graph_id = graph_id

if not bwd:
return gm
Expand All @@ -38,19 +72,21 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
if last_backward_graph_id is None or graph_id != last_backward_graph_id:
return gm

peak_mem = 0
for graph_id, prof in profiling_results.items():
# Use peak memory
fwd_max_mem = max(m[3] for m in prof.fwd_mem)
bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
if dist.get_rank() == 0:
print(
f"selective_gather graph_id={graph_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}"
)
all_graph_mem_records = []
for profile_graph_id, prof in profiling_results.items():
all_graph_mem_records.extend([prof.fwd_mem, prof.bwd_mem])

fwd_peak_resident = max((m[1] for m in prof.fwd_mem), default=0)
fwd_transient_peak = max((m[3] for m in prof.fwd_mem), default=0)
bwd_peak_resident = max((m[1] for m in prof.bwd_mem), default=0)
bwd_transient_peak = max((m[3] for m in prof.bwd_mem), default=0)

print_rank_0(f"selective_gather graph_id={profile_graph_id} "
f"fwd_peak_resident={fwd_peak_resident} fwd_transient_peak={fwd_transient_peak} "
f"bwd_peak_resident={bwd_peak_resident} bwd_transient_peak={bwd_transient_peak}")

persistent_ds_ids = set()
for graph_id, pm in param_manager.items():
for param_graph_id, pm in param_manager.items():
for name, ds_param in pm.params.items():
if ds_param.param.ds_persist:
persistent_ds_ids.add(pm.ds_ids[name])
Expand All @@ -60,13 +96,13 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
ds_id_to_prof_dtime = defaultdict(float)
ds_id_to_prof_wtime = defaultdict(float)

for graph_id, pm in param_manager.items():
for param_graph_id, pm in param_manager.items():
params = pm.params
for param_name, param in params.items():
ds_id = pm.ds_ids[param_name]
ds_id_to_size[ds_id] = param.numel * param.dtype.itemsize

profile = profiling_results[graph_id]
profile = profiling_results[param_graph_id]
for n in profile.fwd_graph.nodes:
if n.target == torch.ops.dc.allgather_param.default:
assert "tensor_size" in n.meta
Expand Down Expand Up @@ -100,39 +136,68 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int
# f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s"
# )

sorted_ds_ids = {ds_id: ds_id_to_size[ds_id] for ds_id in ds_ids}

accelerator = get_accelerator()
total_mem = accelerator.total_memory()
vals_to_bcast = torch.tensor([total_mem], device=torch.device(get_accelerator().current_device()))
current_available_mem = accelerator.available_memory()
vals_to_bcast = torch.tensor([total_mem, current_available_mem],
device=torch.device(get_accelerator().current_device()))
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
total_mem = vals_to_bcast[0].item()
current_available_mem = vals_to_bcast[1].item()

MEM_MARGIN = 0.1
available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem

if dist.get_rank() == 0:
print(
f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}"
)
budget = _compute_persistence_budget(all_graph_mem_records, total_mem, MEM_MARGIN)
available_mem = int(current_available_mem * (1 - MEM_MARGIN))

ds_id_to_param = {}
for g_id, g_pm in param_manager.items():
for name, ds_param in g_pm.params.items():
ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param

candidate_bytes = sum(ds_id_to_size[ds_id] for ds_id in ds_ids)
persistent_bytes = sum(ds_id_to_size.get(ds_id, 0) for ds_id in persistent_ds_ids)

print_rank_0(
f"selective_gather target_graph_id={target_graph_id} profiled_mem_lists={budget['profiled_list_count']} "
f"total_mem={total_mem} usable_mem={budget['usable_mem']} peak_resident_alloc={budget['peak_resident_alloc']} "
f"transient_peak={budget['transient_peak']} current_available_mem={current_available_mem} "
f"usable_available_mem={available_mem} "
f"persistent_count={len(persistent_ds_ids)} persistent_bytes={persistent_bytes} "
f"candidate_count={len(ds_ids)} candidate_bytes={candidate_bytes}")

if budget["profiled_list_count"] == 0:
print_rank_0("selective_gather no profiling data; skipping persistence update")
return gm

if len(ds_ids) == 0:
print_rank_0("selective_gather no candidates to persist")
return gm

if available_mem == 0:
print_rank_0("selective_gather no currently available memory for new persistent params")
return gm

persistent_mem = 0
selected_count = 0
nz3 = get_deepcompile_handle()
for ds_id, size in sorted_ds_ids.items():
for ds_id in ds_ids:
size = ds_id_to_size[ds_id]
if persistent_mem + size > available_mem:
break
persistent_mem += size
selected_count += 1

param_obj = ds_id_to_param[ds_id]

nz3.set_persistent(ds_id)
if dist.get_rank() == 0:
print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")
print_rank_0(
f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")

if selected_count == 0:
smallest_candidate = min(ds_id_to_size[ds_id] for ds_id in ds_ids)
print_rank_0(f"selective_gather selected no new params: available_mem={available_mem} "
f"smallest_candidate={smallest_candidate}")
else:
print_rank_0(f"selective_gather selected_count={selected_count} selected_bytes={persistent_mem}")

return gm

Expand Down
6 changes: 3 additions & 3 deletions deepspeed/compile/passes/zero1_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.fx import GraphModule

from ..util import get_deepcompile_handle
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, get_output_node
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta, add_end_backward, replace_reduce_outputs_with_none

NAME = "zero1_compile"

Expand Down Expand Up @@ -50,8 +50,8 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu

gm.graph = move_primals_to_head(graph)

with gm.graph.inserting_before(get_output_node(gm.graph)):
gm.graph.create_node("call_function", torch.ops.dc.end_backward.default, (graph_id, ))
add_end_backward(gm.graph, graph_id)
replace_reduce_outputs_with_none(gm.graph)

return gm

Expand Down
6 changes: 3 additions & 3 deletions deepspeed/compile/passes/zero3_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.fx import Graph, Node, GraphModule

from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head, add_end_backward, replace_reduce_outputs_with_none
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule

Expand Down Expand Up @@ -209,8 +209,8 @@ def add_z3_gather_release_bw(gm: GraphModule,
0, # unused
debug_log=debug_log)

with gm.graph.inserting_before(get_output_node(gm.graph)):
gm.graph.create_node("call_function", torch.ops.dc.end_backward.default, (graph_id, ))
add_end_backward(gm.graph, graph_id)
replace_reduce_outputs_with_none(gm.graph)

return gm

Expand Down
23 changes: 23 additions & 0 deletions deepspeed/compile/patch_fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ def wrap_if_ds_param(t):
return t


def _get_guard_sizes_strides(t):
if hasattr(t, "ds_id"):
# ZeRO-3 may temporarily all-gather a parameter during tracing, but the
# stable module state used by TorchDynamo guards is the released
# partitioned form, where DeepSpeed resets param.data to empty(0).
released = torch.empty(0, dtype=t.dtype, device=t.device)
return released.size(), released.stride()

return t.size(), t.stride()


def patch_fake_tensor():
# dynamo tracer uses wrap_to_fake_tensor_and_record
# Wrapping FakeTensorMode.from_tensor is not sufficient as dynamo generates SymbolicContext before calling from_tensor
Expand All @@ -37,8 +48,20 @@ def patch_fake_tensor():
def wrap_to_fake_tensor_and_record_wrapper(t, *args, **kwargs):
dummy_tensor = wrap_if_ds_param(t)
ret = original_wrap_to_fake_tensor_and_record(dummy_tensor, *args, **kwargs)
tx = kwargs.get("tx") if "tx" in kwargs else args[0]
source = kwargs.get("source")
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.tensor_to_context[t] = tracing_context.tensor_to_context.pop(dummy_tensor)
if source is not None:
# Keep the full ds_shape symbolic context from the dummy tensor, but
# use the stable released ZeRO-3 parameter representation for
# TorchDynamo's tensor-match guards. PyTorch 2.9 started enforcing
# those guards for parameters during build_guards().
size, stride = _get_guard_sizes_strides(t)
tx.output.input_source_to_sizes_strides[source] = {
"size": size,
"stride": stride,
}
return ret

torch._dynamo.variables.builder.wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record_wrapper
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/compile/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


def is_deepcompile_supported() -> bool:
return required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda"
return required_torch_version(min_version=2.6) and get_accelerator().device_name() == "cuda"


dc_handle = None
Expand Down
Loading
Loading