Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
run_multi_turn_rollout,
)
from nemo_rl.models.generation.interfaces import GenerationInterface
from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration
from nemo_rl.models.generation.redesign.config import SGLangConfig
from nemo_rl.models.generation.redesign.sglang_generation import SGLangGeneration
from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface
Expand Down Expand Up @@ -438,6 +439,10 @@ def init_train_dataloader(dataset, suffix: str = ""):
)
train_cluster = cluster
inference_cluster = cluster
inference_cluster_cfg: ClusterConfig = {
"gpus_per_node": policy_gpus_per_node,
"num_nodes": policy_nodes,
}
print(
f" ✓ Ray cluster for policy initialized with {policy_nodes} nodes",
flush=True,
Expand Down Expand Up @@ -526,6 +531,10 @@ def init_train_dataloader(dataset, suffix: str = ""):
num_gpus_per_node=inference_gpus_per_node,
max_colocated_worker_groups=1,
)
inference_cluster_cfg: ClusterConfig = {
"gpus_per_node": inference_gpus_per_node,
"num_nodes": inference_nodes,
}
print(
f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
flush=True,
Expand Down Expand Up @@ -578,7 +587,11 @@ def init_vllm():
def init_sglang():
"""Initialize SGLang generation workers."""
t0 = time.perf_counter()
pg = SGLangGeneration(cluster=inference_cluster, config=generation_config)
pg = SGLangGeneration(
cluster=inference_cluster,
cluster_cfg=inference_cluster_cfg,
sglang_cfg=generation_config,
)
pg.finish_generation()
return pg, time.perf_counter() - t0

Expand Down Expand Up @@ -1139,15 +1152,11 @@ def refit_policy_generation(
)

if isinstance(policy_generation, SGLangGeneration):
sglang_url_to_gpu_uuids = (
policy_generation.get_sglang_url_to_gpu_uuids()
)
# Stream weights via HTTP
flush_success = policy_generation.invalidate_kv_cache()
if not flush_success:
print("SGLang KV cache invalidation failed before weight update. ")
# Stream weights to colocated SGLang engines via CUDA IPC over HTTP.
# Engine-i owns global ranks [i*K, (i+1)*K) where K = num_gpus_per_engine.
futures_train = policy.stream_weights_via_http(
sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,
rollout_engines=policy_generation.rollout_engines,
num_gpus_per_engine=policy_generation.num_gpus_per_engine,
)
# Wait for all workers to complete
ray.get(futures_train)
Expand Down
1 change: 1 addition & 0 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
"nemo_rl.models.generation.redesign.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.FSDP,
"nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
"nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
Expand Down
48 changes: 48 additions & 0 deletions nemo_rl/distributed/virtual_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,54 @@ def get_gpu_id(self):
return ray.get_gpu_ids()[0]


def get_reordered_bundle_and_gpu_ids(
    pg: PlacementGroup,
) -> tuple[list[int], list[int]]:
    """Return bundle indices and GPU IDs sorted by (node_id, gpu_id).

    Uses ``GetGPUIDActor`` to discover the physical GPU ID assigned to each
    bundle, then sorts identically to the pattern in ``pg.py``.

    Returns:
        (reordered_bundle_indices, reordered_gpu_ids)
    """
    table = placement_group_table(pg)
    bundle_count = len(table["bundles"])
    node_of_bundle = table["bundles_to_node_id"]

    # Launch one lightweight probe actor per bundle so Ray reveals which
    # physical GPU each bundle was assigned.
    probes = [
        GetGPUIDActor.options(
            num_cpus=0.01,
            num_gpus=0.01,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_bundle_index=idx,
            ),
        ).remote()
        for idx in range(bundle_count)
    ]
    gpu_of_bundle = ray.get([probe.get_gpu_id.remote() for probe in probes])
    for probe in probes:
        ray.kill(probe)

    # Sort bundle indices by (node_id, gpu_id); sorted() is stable, so ties
    # preserve the original bundle order — identical to the tuple-sort the
    # previous pattern in pg.py used.
    ordering = sorted(
        range(bundle_count),
        key=lambda idx: (node_of_bundle[idx], gpu_of_bundle[idx]),
    )
    reordered_gpu_ids = [gpu_of_bundle[idx] for idx in ordering]

    for pos, actual_idx in enumerate(ordering):
        logger.info(
            f" bundle {pos:4}, actual_bundle_index: {actual_idx:4}, "
            f"node: {node_of_bundle[actual_idx]}, gpu: {gpu_of_bundle[actual_idx]}"
        )

    return ordering, reordered_gpu_ids


class ResourceInsufficientError(Exception):
"""Exception raised when the cluster does not have enough resources to satisfy the requested configuration."""

Expand Down
33 changes: 33 additions & 0 deletions nemo_rl/models/generation/redesign/async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
import threading


class AsyncLoopThread:
    """Owns a dedicated asyncio event loop running on a background thread.

    Lets synchronous callers execute coroutines without needing to own or
    block an event loop themselves.
    """

    def __init__(self):
        # Daemon thread: the loop never prevents interpreter shutdown.
        self.loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_forever, daemon=True)
        self._thread.start()

    def _run_forever(self):
        # Runs on the background thread; bind the loop there and spin it.
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def run(self, coro):
        """Submit *coro* to the background loop and block until it completes."""
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        return future.result()


# Process-wide singleton, created lazily on first use.
async_loop = None


def get_async_loop():
    """Return the shared ``AsyncLoopThread``, creating it on first call."""
    global async_loop
    if async_loop is not None:
        return async_loop
    async_loop = AsyncLoopThread()
    return async_loop


def run(coro):
    """Run a coroutine in the background event loop."""
    loop_thread = get_async_loop()
    return loop_thread.run(coro)
122 changes: 122 additions & 0 deletions nemo_rl/models/generation/redesign/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, NotRequired, TypedDict

from nemo_rl.models.generation.interfaces import GenerationConfig

class SglangSpecificArgs(TypedDict):
    """SGLang-specific configuration arguments.

    Most fields below map directly to SGLang's ServerArgs (see:
    https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py).

    All fields are optional (``NotRequired``); anything left unset falls back
    to SGLang's own defaults unless the worker overrides it.
    """

    # Model path / HF repo id — presumably mirrors ServerArgs.model_path;
    # verify against the worker that builds ServerArgs.
    model_path: NotRequired[str]
    random_seed: NotRequired[int]
    skip_tokenizer_init: NotRequired[bool]
    # -- CUDA graph / cache toggles --------------------------------------
    disable_cuda_graph: NotRequired[bool]
    disable_radix_cache: NotRequired[bool]
    disable_cuda_graph_padding: NotRequired[bool]
    # Enabling piecewise CUDA graph (i.e. setting this to False) currently crashes with
    # "illegal memory access", likely due to torch 2.10 + sglang incompatibility.
    # Defaulted to True (disabled) in sglang_worker.py until the upstream sglang fork is updated.
    disable_piecewise_cuda_graph: NotRequired[bool]
    enable_nccl_nvls: NotRequired[bool]
    disable_outlines_disk_cache: NotRequired[bool]
    disable_custom_all_reduce: NotRequired[bool]
    disable_overlap_schedule: NotRequired[bool]
    enable_mixed_chunk: NotRequired[bool]
    enable_dp_attention: NotRequired[bool]
    enable_deepep_moe: NotRequired[bool]
    enable_ep_moe: NotRequired[bool]
    # -- torch.compile / CUDA graph sizing -------------------------------
    enable_torch_compile: NotRequired[bool]
    torch_compile_max_bs: NotRequired[int]
    cuda_graph_max_bs: NotRequired[int | None]
    cuda_graph_bs: NotRequired[list[int] | None]
    torchao_config: NotRequired[str]
    # -- debugging / numerics --------------------------------------------
    enable_nan_detection: NotRequired[bool]
    enable_p2p_check: NotRequired[bool]
    triton_attention_reduce_in_fp32: NotRequired[bool]
    triton_attention_num_kv_splits: NotRequired[int]
    num_continuous_decode_steps: NotRequired[int]
    allow_auto_truncate: NotRequired[bool]
    attention_backend: NotRequired[str | None]
    enable_multimodal: NotRequired[bool]
    sampling_backend: NotRequired[str | None]
    # -- memory / scheduling ---------------------------------------------
    context_length: NotRequired[int | None]
    mem_fraction_static: NotRequired[float | None]
    max_running_requests: NotRequired[int | None]
    chunked_prefill_size: NotRequired[int | None]
    max_prefill_tokens: NotRequired[int]
    schedule_policy: NotRequired[str]
    schedule_conservativeness: NotRequired[float]
    cpu_offload_gb: NotRequired[int]
    dtype: NotRequired[str]
    kv_cache_dtype: NotRequired[str]
    # -- parallelism sizes ------------------------------------------------
    dp_size: NotRequired[int]  # only used for dp attention
    pp_size: NotRequired[int]  # pipeline parallel size
    ep_size: NotRequired[int]
    # lora
    enable_lora: NotRequired[bool | None]
    max_lora_rank: NotRequired[int | None]
    lora_target_modules: NotRequired[list[str] | None]
    lora_paths: NotRequired[list[str] | None]
    max_loaded_loras: NotRequired[int]
    max_loras_per_batch: NotRequired[int]
    lora_backend: NotRequired[str]
    # logging
    log_level: NotRequired[str]
    log_level_http: NotRequired[str | None]
    log_requests: NotRequired[bool]
    log_requests_level: NotRequired[int]
    show_time_cost: NotRequired[bool]
    enable_metrics: NotRequired[bool]  # Exports Prometheus-like metrics
    # The interval (in decoding iterations) to log throughput
    # and update prometheus metrics
    decode_log_interval: NotRequired[int]
    # Extra loader arguments
    enable_multithread_load: NotRequired[bool]
    enable_fast_load: NotRequired[bool]
    # Server warmup
    skip_server_warmup: NotRequired[bool]
    # Fault tolerance — NOTE(review): these look NeMo-RL-specific rather than
    # ServerArgs fields; intervals/timeouts are presumably in seconds — confirm
    # against the worker's health-check loop.
    use_fault_tolerance: NotRequired[bool]
    rollout_health_check_interval: NotRequired[int]
    rollout_health_check_timeout: NotRequired[int]
    rollout_health_check_first_wait: NotRequired[int]

class SGLangServer(TypedDict):
    """Server-level settings for the SGLang engines.

    NOTE(review): field semantics below are inferred from the original inline
    notes — confirm against the worker that consumes this config.
    """

    # When True, enables SGLang's memory saver (needs_offload true -->
    # enable_memory_saver true, per the original note) so the engine can
    # release GPU memory during training.
    needs_offload: bool
    # For testing purposes only: back up weights on CPU via the memory saver.
    cpu_weight_backup: bool
    sglang_server_concurrency: int
    # Total number of GPUs used for inference (across all engines).
    num_gpus: NotRequired[int]
    # Number of GPUs owned by each individual engine.
    num_gpus_per_engine: NotRequired[int]

class SGLangRouter(TypedDict):
    """Settings for the SGLang router (load balancer) front-end."""

    # Address the router binds to / is reached at.
    sglang_router_ip: NotRequired[str]
    sglang_router_port: NotRequired[int]
    # Load-balancing policy name forwarded to the router.
    router_policy: NotRequired[str]
    # NOTE(review): semantics unclear from here — presumably POSTs requests
    # to engines in a distributed fashion; confirm against the router worker.
    use_distributed_post: NotRequired[bool]

class SGLangConfig(GenerationConfig):
    """Configuration for SGLang runtime."""
    # Engine-level options, mostly mapping onto SGLang's ServerArgs.
    sglang_cfg: SglangSpecificArgs
    # Server topology and memory-offload behavior.
    sglang_server: SGLangServer
    # Router (load balancer) connection and policy settings.
    sglang_router: SGLangRouter
    # Escape hatch: extra keyword args passed through to SGLang unvalidated.
    sglang_kwargs: NotRequired[dict[str, Any]]
Loading
Loading