# Reconstructed from a newline-collapsed `git format-patch` dump
# (PATCH 01/36, "update", 2026-04-02) creating
# nemo_rl/models/generation/redesign/sglang_worker.py.
import dataclasses
import ipaddress
import logging
import multiprocessing
import os
import random  # FIX: used by find_available_port but missing from the original imports
import socket  # FIX: used by is_port_available / get_host_info but missing from the original imports
import time
from urllib.parse import quote

import requests
import sglang_router
from packaging.version import parse
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from urllib3.exceptions import NewConnectionError

import ray
from miles.utils.env_report import collect_and_print_node_env_report

logger = logging.getLogger(__name__)


def get_base_gpu_id(args, rank):
    """Return the first physical GPU index on the local node for engine *rank*.

    Rollout engines are packed after actor (and critic, when enabled) GPUs
    unless the deployment is colocated or rollout-only.
    """
    num_gpus = min(args.num_gpus_per_node, args.rollout_num_gpus_per_engine)
    if args.colocate:
        start_index = (rank * num_gpus) % args.num_gpus_per_node
    else:
        num_actor_gpus = 0 if args.debug_rollout_only else args.actor_num_gpus_per_node * args.actor_num_nodes
        start_index = (num_actor_gpus + rank * num_gpus) % args.num_gpus_per_node
        if args.use_critic:
            num_critic_gpus = args.critic_num_gpus_per_node * args.critic_num_nodes
            start_index = (num_actor_gpus + num_critic_gpus + rank * num_gpus) % args.num_gpus_per_node
    return start_index


def _to_local_gpu_id(physical_gpu_id: int) -> int:
    """Map a physical GPU id to the torch-visible index under CUDA_VISIBLE_DEVICES.

    Raises:
        RuntimeError: if the id is neither a visible physical id nor a valid
            local index under the current CUDA_VISIBLE_DEVICES remapping.
    """
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if not cvd:
        return physical_gpu_id  # no remapping
    # CUDA_VISIBLE_DEVICES can be like "4,5,6,7"
    visible = [int(x) for x in cvd.split(",") if x.strip() != ""]
    # In a remapped process, valid torch device indices are 0..len(visible)-1
    if physical_gpu_id in visible:
        return visible.index(physical_gpu_id)
    # If we're already getting local IDs, allow them
    if 0 <= physical_gpu_id < len(visible):
        return physical_gpu_id
    raise RuntimeError(
        f"GPU id {physical_gpu_id} is not valid under CUDA_VISIBLE_DEVICES={cvd}. "
        f"Expected one of {visible} (physical) or 0..{len(visible)-1} (local)."
    )


def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
    """Spawn the SGLang HTTP server in a child process.

    Node rank 0 additionally waits for the server to become healthy.

    Returns:
        The child ``multiprocessing.Process`` handle for every node rank.
    """
    from sglang.srt.entrypoints.http_server import launch_server

    multiprocessing.set_start_method("spawn", force=True)
    server_args.host = server_args.host.strip("[]")
    p = multiprocessing.Process(target=launch_server, args=(server_args,))
    p.start()

    # BUG FIX: the original returned None here, so callers on non-zero node
    # ranks stored `self.process = None` and later crashed in
    # `kill_process_tree(self.process.pid)` during shutdown.
    if server_args.node_rank != 0:
        return p

    _wait_server_healthy(
        base_url=server_args.url(),
        api_key=server_args.api_key,
        is_process_alive=lambda: p.is_alive(),
    )

    return p


def _wait_server_healthy(base_url, api_key, is_process_alive):
    """Block until the server answers /health_generate and /flush_cache.

    Raises:
        Exception: if the server process dies while we are waiting.
    """
    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {api_key}",
    }

    def _poll_until_ok(session, endpoint):
        # Poll every 2s until the endpoint answers 200 or the process dies.
        while True:
            try:
                if session.get(f"{base_url}/{endpoint}", headers=headers).status_code == 200:
                    return
            except requests.RequestException:
                pass
            if not is_process_alive():
                raise Exception("Server process terminated unexpectedly.")
            time.sleep(2)

    with requests.Session() as session:
        _poll_until_ok(session, "health_generate")
        # use flush_cache to make sure the working queue is empty, so that we can do offload
        _poll_until_ok(session, "flush_cache")


def find_available_port(base_port: int):
    """Return a free port, starting from a randomized offset above *base_port*.

    Walks upward in steps of 42 below port 60000, then back down, so repeated
    callers on the same host are unlikely to collide.
    """
    port = base_port + random.randint(100, 1000)
    while True:
        if is_port_available(port):
            return port
        if port < 60000:
            port += 42
        else:
            port -= 43
def is_port_available(port):
    """Return whether a port is available."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
            probe.listen(1)
        except (OSError, OverflowError):
            # OverflowError covers ports outside the 0-65535 range.
            return False
        return True


def get_host_info():
    """Return ``(hostname, best-guess LAN IP)`` for this machine.

    Resolution order: explicit env override, UDP connect probe for the
    preferred address family, hostname resolution, then that family's
    loopback address. Families are never mixed.
    """
    hostname = socket.gethostname()

    # NOTE(review): MILES_HOST_IP_ENV is not defined or imported in this
    # file — presumably a constant from a miles.utils module; verify.
    if env_overwrite_local_ip := os.getenv(MILES_HOST_IP_ENV, None):
        return hostname, env_overwrite_local_ip

    def _is_loopback(ip):
        return ip.startswith("127.") or ip == "::1"

    def _resolve_ip(family, probe_target):
        """Best-effort LAN IP for one family: UDP probe, then hostname lookup."""
        # Strategy 1: UDP connect probe (most accurate, relies on the routing
        # table). The target never needs to be reachable; useful whenever a
        # default gateway or internet route exists.
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect((probe_target, 80))
                candidate = s.getsockname()[0]
                if not _is_loopback(candidate):
                    return candidate
        except Exception:
            pass  # route unreachable or network error — fall through

        # Strategy 2: hostname resolution, for offline clusters where the
        # probe fails but /etc/hosts is configured. getaddrinfo lets us pin
        # the family; sockaddr[0] is the IP.
        try:
            for info in socket.getaddrinfo(hostname, None, family=family, type=socket.SOCK_STREAM):
                candidate = info[4][0]
                # Filter loopback to avoid handing out "127.0.0.1".
                if not _is_loopback(candidate):
                    return candidate
        except Exception:
            pass

        return None

    prefer_ipv6 = os.getenv("MILES_PREFER_IPV6", "0").lower() in ("1", "true", "yes", "on")
    if prefer_ipv6:
        # Strict IPv6-only mode: probe, then DNS; fall back to "::1".
        local_ip = _resolve_ip(socket.AF_INET6, "2001:4860:4860::8888")
        final_fallback = "::1"
    else:
        # Strict IPv4-only mode (default): probe, then DNS; fall back to "127.0.0.1".
        local_ip = _resolve_ip(socket.AF_INET, "8.8.8.8")
        final_fallback = "127.0.0.1"

    return hostname, local_ip or final_fallback


def get_current_node_ip():
    """Return this Ray node's IP address with IPv6 brackets stripped."""
    # NOTE(review): relies on a private Ray API; re-verify on Ray upgrades.
    address = ray._private.services.get_node_ip_address()
    # strip ipv6 address
    return address.strip("[]")


def get_free_port(start_port=10000, consecutive=1):
    """Return the first port p such that p .. p+consecutive-1 are all free."""
    candidate = start_port
    while not all(is_port_available(candidate + offset) for offset in range(consecutive)):
        candidate += 1
    return candidate
port + consecutive - 1 are all available + port = start_port + while not all(is_port_available(port + i) for i in range(consecutive)): + port += 1 + return port + +class SGLangEngine: + def __init__( + self, + args, + rank: int, + worker_type: str = "regular", + base_gpu_id: int | None = None, + sglang_overrides: dict | None = None, + num_gpus_per_engine: int | None = None, + ): + self.args = args + self.rank = rank + self.worker_type = worker_type + self.base_gpu_id = base_gpu_id + self.sglang_overrides = sglang_overrides or {} + self.num_gpus_per_engine = num_gpus_per_engine + + def init( + self, + dist_init_addr, + port, + nccl_port, + host=None, + disaggregation_bootstrap_port=None, + router_ip=None, + router_port=None, + ): + if env_report := self.args.env_report: + collect_and_print_node_env_report( + role="rollout", + rank=self.rank, + partial_env_report=env_report, + ) + + self.router_ip = router_ip if router_ip is not None else self.args.sglang_router_ip + self.router_port = router_port if router_port is not None else self.args.sglang_router_port + + host = host or get_host_info()[1] + + def _format_v6_uri(addr): + if not addr or addr.startswith("["): + return addr + try: + if ipaddress.ip_address(addr).version == 6: + return f"[{addr}]" + except ValueError: + pass + return addr + + host = _format_v6_uri(host) + ip_part, port_part = dist_init_addr.rsplit(":", 1) + dist_init_addr = f"{_format_v6_uri(ip_part)}:{port_part}" + + server_args_dict, _ = _compute_server_args( + self.args, + self.rank, + dist_init_addr, + nccl_port, + host, + port, + self.worker_type, + disaggregation_bootstrap_port, + base_gpu_id=self.base_gpu_id, + sglang_overrides=self.sglang_overrides, + num_gpus_per_engine=self.num_gpus_per_engine, + ) + + self.node_rank = server_args_dict["node_rank"] + self.server_host = server_args_dict["host"] # with [] if ipv6 + self.server_port = server_args_dict["port"] + + self._init_normal(server_args_dict) + + + def _init_normal(self, 
server_args_dict): + logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") + self.process = launch_server_process(ServerArgs(**server_args_dict)) + + if self.node_rank == 0 and self.router_ip and self.router_port: + if parse(sglang_router.__version__) <= parse("0.2.1") or self.args.use_miles_router: + assert ( + self.worker_type == "regular" + ), "pd disaggregation is not supported in old router or miles router." + response = requests.post( + f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}" + ) + else: + payload = { + "url": f"http://{self.server_host}:{self.server_port}", + "worker_type": self.worker_type, + } + response = requests.post( + f"http://{self.router_ip}:{self.router_port}/workers", + json=payload, + ) + response.raise_for_status() + + def _make_request(self, endpoint: str, payload: dict | None = None): + """Make a POST request to the specified endpoint with the given payload. + + Args: + endpoint: The API endpoint to call + payload: The JSON payload to send (default: empty dict) + + Returns: + The JSON response from the server + """ + if self.node_rank != 0: + return + + url = f"http://{self.server_host}:{self.server_port}/{endpoint}" + response = requests.post(url, json=payload or {}) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + e.add_note(f"{response.text=}") + raise + return response.json() + + @staticmethod + def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): + return get_current_node_ip(), get_free_port(start_port=start_port, consecutive=consecutive) + + def get_master_addr_and_port(self): + return self.master_addr, self.master_port + + def health_generate(self, timeout: float = 5.0) -> bool: + """Run /health_generate on the underlying SGLang HTTP server. + + Args: + timeout: Timeout for the health request in seconds. + + Returns: + True if the server responds with HTTP 200. 
+ + Raises: + requests.RequestException: If the request fails for any reason, including timeout. + """ + if self.node_rank != 0: + return True + + response = requests.get( + f"http://{self.server_host}:{self.server_port}/health_generate", + timeout=timeout, + ) + response.raise_for_status() + return True + + def update_weights_from_tensor( + self, + serialized_named_tensors: list[str], + load_format: str | None = None, + flush_cache: bool = False, + weight_version: str | None = None, + ): + """ + Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs. + + Note: The model should be on GPUs rather than CPU for this functionality to work properly. + If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. + """ + payload = { + "serialized_named_tensors": serialized_named_tensors, + "load_format": load_format, + "flush_cache": flush_cache, + } + if weight_version is not None: + payload["weight_version"] = weight_version + return self._make_request( + "update_weights_from_tensor", + payload, + ) + + def flush_cache(self): + """Flush the cache of the server.""" + if self.node_rank != 0: + return + # flush cache will not return status_code 200 when there are pending requests + for _ in range(60): + try: + response = requests.get(f"http://{self.server_host}:{self.server_port}/flush_cache") + if response.status_code == 200: + break + except NewConnectionError as e: + raise e + except Exception as e: + logger.info(f"Error flushing cache: {e}") + time.sleep(1) + continue + else: + raise TimeoutError("Timeout while flushing cache.") + + def shutdown(self): + + logger.info(f"Shutdown engine {self.server_host}:{self.server_port}...") + if self.node_rank == 0: + worker_url = f"http://{self.server_host}:{self.server_port}" + response = None + if parse(sglang_router.__version__) <= parse("0.2.1") or self.args.use_miles_router: + response = requests.post( + 
f"http://{self.router_ip}:{self.router_port}/remove_worker?url=http://{self.server_host}:{self.server_port}" + ) + elif parse(sglang_router.__version__) < parse("0.3.0"): + worker_url = quote(worker_url, safe="") + response = requests.delete(f"http://{self.router_ip}:{self.router_port}/workers/{worker_url}") + else: + try: + all_workers = requests.get(f"http://{self.router_ip}:{self.router_port}/workers").json()["workers"] + for worker in all_workers: + if worker["url"] == worker_url: + worker_id = worker["id"] + response = requests.delete( + f"http://{self.router_ip}:{self.router_port}/workers/{worker_id}" + ) + break + else: + logger.warning(f"Worker {worker_url} not found in router during shutdown.") + except Exception as e: + logger.warning(f"Failed to fetch workers list or remove worker: {e}") + + if response is not None: + response.raise_for_status() + kill_process_tree(self.process.pid) + + def get_weight_version(self): + if self.node_rank != 0: + return + base = f"http://{self.server_host}:{self.server_port}" + # new sglang change api from /get_weight_version to /model_info + for endpoint in ("/model_info", "/get_weight_version"): + response = requests.get(f"{base}{endpoint}") + if response.status_code == 200: + return response.json()["weight_version"] + response.raise_for_status() + + def release_memory_occupation(self, tags: list[str] = None): + """Release memory occupation. 
Available tags: weights, kv_cache.""" + self.flush_cache() + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) + + def resume_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage resume: weights, kv_cache + """ + return self._make_request( + "resume_memory_occupation", + {"tags": tags}, + ) + + def check_weights(self, action: str): + return self._make_request("weights_checker", {"action": action}) + + def update_weights_from_disk(self, model_path: str, load_format: str | None = None): + """Reload weights from *model_path* without restarting the engine. + + Used for non-updatable (frozen) models that overlap with megatron: + after offload, weights are restored from disk instead of CPU cache. + """ + payload = {"model_path": model_path} + if load_format is not None: + payload["load_format"] = load_format + return self._make_request("update_weights_from_disk", payload) + + def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): + return self._make_request( + "init_weights_update_group", + { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + }, + ) + + def destroy_weights_update_group(self, group_name): + try: + return self._make_request( + "destroy_weights_update_group", + { + "group_name": group_name, + }, + ) + except requests.exceptions.RequestException: + # catch the case there the engine is just created and does not have the group. 
+ pass + + def update_weights_from_distributed( + self, names, dtypes, shapes, group_name, flush_cache=False, weight_version: str | None = None + ): + payload = { + "names": names, + "dtypes": [str(dtype).replace("torch.", "") for dtype in dtypes], + "shapes": shapes, + "group_name": group_name, + "flush_cache": flush_cache, + } + if weight_version is not None: + payload["weight_version"] = weight_version + return self._make_request( + "update_weights_from_distributed", + payload, + ) + + def pause_generation(self): + response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) + response.raise_for_status() + return response + + def continue_generation(self): + response = requests.post(f"http://{self.server_host}:{self.server_port}/continue_generation", json={}) + response.raise_for_status() + return response + + def post_process_weights( + self, + restore_weights_before_load: bool = False, + post_process_quantization: bool = False, + ): + """ + Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs. + Note: The model should be on GPUs rather than CPU for this functionality to work properly. + If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. + """ + + return self._make_request( + "post_process_weights", + { + "restore_weights_before_load": restore_weights_before_load, + "post_process_quantization": post_process_quantization, + }, + ) + + def start_profile( + self, + # The output directory + output_dir: str | None = None, + # If set, it profile as many as this number of steps. + # If it is set, profiling is automatically stopped after this step, and + # the caller doesn't need to run stop_profile. 
+ start_step: int | None = None, + num_steps: int | None = None, + activities: list[str] | None = None, + profile_by_stage: bool = False, + with_stack: bool | None = None, + record_shapes: bool | None = None, + ): + response = requests.post( + f"http://{self.server_host}:{self.server_port}/start_profile", + json={ + "output_dir": output_dir, + "start_step": start_step, + "num_steps": num_steps, + "activities": activities, + "profile_by_stage": profile_by_stage, + "with_stack": with_stack, + "record_shapes": record_shapes, + }, + ) + response.raise_for_status() + return response + + def stop_profile(self): + response = requests.post(f"http://{self.server_host}:{self.server_port}/stop_profile", json={}) + response.raise_for_status() + return response + + def simulate_crash(self): + logger.info(f"Simulating crash on engine {self.server_host}:{self.server_port}...") + self.shutdown() + + +def _compute_server_args( + args, + rank, + dist_init_addr, + nccl_port, + host, + port, + worker_type: str = "regular", + disaggregation_bootstrap_port: int | None = None, + base_gpu_id: int | None = None, + sglang_overrides: dict | None = None, + num_gpus_per_engine: int | None = None, +): + _gpus_per_engine = num_gpus_per_engine or args.rollout_num_gpus_per_engine + nnodes = max(1, _gpus_per_engine // args.num_gpus_per_node) + node_rank = rank % nnodes + base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(args, rank) + base = _to_local_gpu_id(base) + kwargs = { + "model_path": args.hf_checkpoint, + "trust_remote_code": True, + "random_seed": args.seed + rank, + # memory + "enable_memory_saver": args.offload_rollout, + # distributed + "host": host, + "port": port, + "nccl_port": nccl_port, + "nnodes": nnodes, + "node_rank": node_rank, + "dist_init_addr": dist_init_addr, + "gpu_id_step": 1, + "base_gpu_id": base, + # parallel + "tp_size": _gpus_per_engine, + "dp_size": args.sglang_dp_size, + "pp_size": args.sglang_pp_size, + "ep_size": args.sglang_ep_size, + # always 
def _compute_server_args(
    args,
    rank,
    dist_init_addr,
    nccl_port,
    host,
    port,
    worker_type: str = "regular",
    disaggregation_bootstrap_port: int | None = None,
    base_gpu_id: int | None = None,
    sglang_overrides: dict | None = None,
    num_gpus_per_engine: int | None = None,
):
    """Build the kwargs dict used to construct sglang's ServerArgs.

    Precedence: computed defaults < *sglang_overrides* < per-field
    ``args.sglang_<name>`` attributes. Keys that the installed sglang's
    ServerArgs does not know are dropped with a warning.

    Returns:
        (kwargs, external_engine_need_check_fields) where the second element
        lists the keys an external engine must be checked against.
    """
    # NOTE(review): worker_type / disaggregation_bootstrap_port are accepted
    # but unused in this visible body — presumably consumed by a PD-disagg
    # code path elsewhere; verify before removing.
    _gpus_per_engine = num_gpus_per_engine or args.rollout_num_gpus_per_engine
    nnodes = max(1, _gpus_per_engine // args.num_gpus_per_node)
    node_rank = rank % nnodes
    base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(args, rank)
    base = _to_local_gpu_id(base)
    kwargs = {
        "model_path": args.hf_checkpoint,
        "trust_remote_code": True,
        "random_seed": args.seed + rank,
        # memory
        "enable_memory_saver": args.offload_rollout,
        # distributed
        "host": host,
        "port": port,
        "nccl_port": nccl_port,
        "nnodes": nnodes,
        "node_rank": node_rank,
        "dist_init_addr": dist_init_addr,
        "gpu_id_step": 1,
        "base_gpu_id": base,
        # parallel
        "tp_size": _gpus_per_engine,
        "dp_size": args.sglang_dp_size,
        "pp_size": args.sglang_pp_size,
        "ep_size": args.sglang_ep_size,
        # always skip warmup to prevent warmup timeout.
        "skip_server_warmup": True,
        # always enable draft weights cpu backup so that we run training without mtp weights.
        "enable_draft_weights_cpu_backup": True,
    }

    if sglang_overrides:
        kwargs.update(sglang_overrides)

    if args.use_rollout_routing_replay:
        kwargs["enable_return_routed_experts"] = True
    if args.fp16:
        kwargs["dtype"] = "float16"
    external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS]

    unused_keys = set(kwargs.keys())
    for attr in dataclasses.fields(ServerArgs):
        # Per-field `sglang_<name>` attributes on args win unless already set.
        if hasattr(args, f"sglang_{attr.name}") and attr.name not in kwargs:
            kwargs[attr.name] = getattr(args, f"sglang_{attr.name}")
        unused_keys.discard(attr.name)

    # for compatibility with old args: drop keys the installed sglang rejects.
    if len(unused_keys) > 0:
        # FIX: was logger.info with "Warning: ... arguments is not supported".
        logger.warning(f"The following arguments are not supported in the current sglang: {unused_keys}.")
        for key in unused_keys:
            kwargs.pop(key)

    return kwargs, external_engine_need_check_fields


# Fields that external (pre-launched) engines are NOT checked against,
# because they are engine-local by nature.
_EXTERNAL_ENGINE_SKIP_CHECK_FIELDS = [
    "model_path",
    "trust_remote_code",
    "random_seed",
    "nccl_port",
    "dist_init_addr",
    "skip_server_warmup",
    "enable_draft_weights_cpu_backup",
    "mem_fraction_static",
]

# --- git format-patch boundary: PATCH 02/36 ("update", 2026-04-03) follows,
# --- adding pg.py, sglang_generation.py and utils.py under
# --- nemo_rl/models/generation/redesign/.
# Reconstructed nemo_rl/models/generation/redesign/pg.py (from PATCH 02/36):
# placement-group creation and train/rollout group wiring.
import logging
import socket

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from .actor_group import RayTrainGroup
from .rollout import RolloutManager

logger = logging.getLogger(__name__)


# FIX: InfoActor was referenced in _create_placement_group but never defined
# or imported anywhere in this file (NameError on first use). Defined here to
# match its call site: one GPU per bundle, reporting (node_identifier, gpu_id).
@ray.remote(num_gpus=1)
class InfoActor:
    def get_ip_and_gpu_id(self):
        return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0]


def sort_key(x):
    index, node_identifier, gpu_id = x
    # Sort by node IP number and then by GPU ID
    try:
        # try to parse it as an IP address.
        ip_address = node_identifier
        node_ip_parts = list(map(int, ip_address.split(".")))
    except ValueError:
        # Try to resolve the hostname to an IP address.
        try:
            ip_address = socket.gethostbyname(node_identifier)
            node_ip_parts = list(map(int, ip_address.split(".")))
        except (socket.gaierror, TypeError):
            # Instead, we convert each character of the original identifier string
            # to its ASCII value. This provides a stable and consistent numerical
            # representation that allows for sorting.
            node_ip_parts = [ord(c) for c in node_identifier]

    return (node_ip_parts, gpu_id)


def _create_placement_group(num_gpus):
    """Create a placement group with the specified number of GPUs.

    Returns (pg, reordered_bundle_indices, reordered_gpu_ids) where the
    reorderings give a stable node-then-GPU ordering of the bundles.
    """
    bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)]
    pg = placement_group(bundles, strategy="PACK")
    num_bundles = len(bundles)

    ray.get(pg.ready())
    # use info actor to get the GPU id
    info_actors = []
    for i in range(num_bundles):
        info_actors.append(
            InfoActor.options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=i,
                )
            ).remote()
        )
    gpu_ids = ray.get([actor.get_ip_and_gpu_id.remote() for actor in info_actors])
    for actor in info_actors:
        ray.kill(actor)

    bundle_infos = [(i, gpu_ids[i][0], gpu_ids[i][1]) for i in range(num_bundles)]
    sorted_bundle_infos = sorted(bundle_infos, key=sort_key)
    pg_reordered_bundle_indices = [info[0] for info in sorted_bundle_infos]
    # Map from logical index -> physical GPU ID
    pg_reordered_gpu_ids = [gpu_ids[info[0]][1] for info in sorted_bundle_infos]

    for i in range(num_bundles):
        actual_bundle_index = pg_reordered_bundle_indices[i]
        logger.info(
            f" bundle {i:4}, actual_bundle_index: {actual_bundle_index:4}, "
            f"node: {gpu_ids[actual_bundle_index][0]}, gpu: {gpu_ids[actual_bundle_index][1]}"
        )

    return pg, pg_reordered_bundle_indices, pg_reordered_gpu_ids


def create_placement_groups(args):
    """Create placement groups for actor and rollout engines."""

    num_gpus = 0
    if args.debug_train_only:
        num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node
        rollout_offset = 0
        if args.use_critic:
            num_gpus += args.critic_num_nodes * args.critic_num_gpus_per_node
            critic_offset = args.actor_num_nodes * args.actor_num_gpus_per_node
    elif args.debug_rollout_only:
        # NOTE(review): use_critic combined with debug_rollout_only leaves
        # critic_offset unbound and would raise below — confirm unreachable.
        num_gpus = args.rollout_num_gpus
        rollout_offset = 0
    elif args.colocate:
        num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node
        rollout_offset = 0
        if args.use_critic:
            num_gpus += args.critic_num_nodes * args.critic_num_gpus_per_node
            critic_offset = args.actor_num_nodes * args.actor_num_gpus_per_node
    else:
        num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus
        rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node
        if args.use_critic:
            num_gpus += args.critic_num_nodes * args.critic_num_gpus_per_node
            critic_offset = args.actor_num_nodes * args.actor_num_gpus_per_node
            rollout_offset += args.critic_num_nodes * args.critic_num_gpus_per_node

    logger.info(f"Creating placement group with {num_gpus} GPUs...")
    pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids = _create_placement_group(num_gpus)

    rollout_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[rollout_offset:]
    rollout_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[rollout_offset:]
    if args.use_critic:
        critic_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[critic_offset:]
        critic_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[critic_offset:]

    return {
        "actor": (pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids),
        "critic": (pg, critic_pg_reordered_bundle_indices, critic_pg_reordered_gpu_ids) if args.use_critic else None,
        "rollout": (pg, rollout_pg_reordered_bundle_indices, rollout_pg_reordered_gpu_ids),
    }


def allocate_train_group(args, num_nodes, num_gpus_per_node, pg):
    """Build a RayTrainGroup over *pg* (fractional GPUs allow colocation)."""
    return RayTrainGroup(
        args=args,
        num_nodes=num_nodes,
        num_gpus_per_node=num_gpus_per_node,
        pg=pg,
        num_gpus_per_actor=0.4,
    )


def create_training_models(args, pgs, rollout_manager):
    """Initialize actor (and optional critic) groups and wire them to rollout."""
    actor_model = allocate_train_group(
        args=args,
        num_nodes=args.actor_num_nodes,
        num_gpus_per_node=args.actor_num_gpus_per_node,
        pg=pgs["actor"],
    )
    if args.use_critic:
        critic_model = allocate_train_group(
            args=args,
            num_nodes=args.critic_num_nodes,
            num_gpus_per_node=args.critic_num_gpus_per_node,
            pg=pgs["critic"],
        )
        critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
    else:
        critic_model = None

    # A reference model is only needed when a KL term is in play.
    start_rollout_ids = ray.get(
        actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
    )

    assert len(set(start_rollout_ids)) == 1
    if args.start_rollout_id is None:
        args.start_rollout_id = start_rollout_ids[0]

    if args.use_critic:
        ray.get(critic_init_handle)
        actor_model.connect(critic_model)

    actor_model.set_rollout_manager(rollout_manager)
    if args.rollout_global_dataset:
        ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))

    return actor_model, critic_model


def create_rollout_manager(args, pg):
    """Start the RolloutManager actor and derive num_rollout from num_epoch."""
    rollout_manager = RolloutManager.options(
        num_cpus=1,
        num_gpus=0,
    ).remote(args, pg)

    # calculate num_rollout from num_epoch
    num_rollout_per_epoch = None
    if args.num_rollout is None:
        num_rollout_per_epoch = ray.get(rollout_manager.get_num_rollout_per_epoch.remote())
        args.num_rollout = num_rollout_per_epoch * args.num_epoch
    assert args.num_rollout > 0

    if args.check_weight_update_equal:
        ray.get(rollout_manager.check_weights.remote(action="snapshot"))
        ray.get(rollout_manager.check_weights.remote(action="reset_tensors"))

    if args.offload_rollout:
        ray.get(rollout_manager.offload.remote())

    return rollout_manager, num_rollout_per_epoch
GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

from miles.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig
from miles.backends.sglang_utils.sglang_engine import SGLangEngine
from miles.rollout.base_types import (
    RolloutFnConstructorInput,
    RolloutFnEvalInput,
    RolloutFnTrainInput,
    call_rollout_fn,
)
from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
from miles.utils import tracking_utils
from miles.utils.environ import enable_experimental_rollout_refactor
from miles.utils.health_monitor import RolloutHealthMonitor
from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client
from miles.utils.iter_utils import group_by
from miles.utils.logging_utils import configure_logger
from miles.utils.metric_checker import MetricChecker
from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
from miles.utils.misc import load_function
from miles.utils.ray_utils import Box
from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
from miles.utils.tracking_utils import init_tracking
from miles.utils.types import Sample

from ..utils.metric_utils import has_repetition
from .utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, Lock

# NOTE(review): names used below but not imported in this view — `torch`, `np`,
# `itertools`, `Path`, `random`, `Any`, `PlacementGroupSchedulingStrategy` —
# presumably imported earlier in the file; confirm before merging.

# Quiet the HTTP client libraries used for router/engine traffic.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# ServerGroup / RolloutServer abstractions
# ---------------------------------------------------------------------------


@dataclasses.dataclass
class ServerGroup:
    """A group of homogeneous SGLang engines with the same configuration.

    All engines in a group share the same tp_size / nodes_per_engine / pg.
    A RolloutServer may contain multiple ServerGroups (e.g. prefill vs decode
    in PD disaggregation).
    """

    args: Any
    pg: Any  # (placement_group, reordered_bundle_indices, reordered_gpu_ids)
    # Full engine list; a None entry means "dead / not yet started".
    all_engines: list
    num_gpus_per_engine: int
    num_new_engines: int
    worker_type: str = "regular"  # "regular", "prefill", or "decode"
    rank_offset: int = 0  # global engine-rank offset of this group
    gpu_offset: int = 0  # offset into the reordered GPU list for this group
    sglang_overrides: dict = dataclasses.field(default_factory=dict)
    needs_offload: bool = False  # whether engines must release GPU memory for training
    model_path: str | None = None
    router_ip: str | None = None
    router_port: int | None = None

    @property
    def nodes_per_engine(self):
        # At least 1; >1 only when a single engine spans multiple nodes.
        return max(1, self.num_gpus_per_engine // self.args.num_gpus_per_node)

    @property
    def engines(self):
        """Node-0 engines only (for multi-node serving)."""
        return self.all_engines[:: self.nodes_per_engine]

    # NOTE(review): `num_new_engines` is a plain dataclass field, not a
    # `property`, so `@num_new_engines.setter` is invalid at class-definition
    # time — and even if it ran, `self.num_new_engines = value` inside the
    # setter would recurse infinitely. This setter should be deleted (field
    # assignment already works) or the field turned into a real property.
    @num_new_engines.setter
    def num_new_engines(self, value):
        self.num_new_engines = value

    @property
    def engine_gpu_counts(self) -> list[int]:
        """Per-engine GPU count for all node-0 engines, parallel to ``engines``."""
        return [self.num_gpus_per_engine for _ in self.engines]

    @property
    def engine_gpu_offsets(self) -> list[int]:
        # Offset of each node-0 engine into the group's GPU range,
        # parallel to ``engines``.
        offsets = []
        for j in range(len(self.engines)):
            offsets.append(self.gpu_offset + j * self.num_gpus_per_engine)
        return offsets

    def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[list, dict[int, int]]:
        """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting.

        Returns ``(init_handles, port_cursors)`` where *init_handles* is a list
        of Ray ObjectRefs and *port_cursors* maps node index -> next free port.
        """
        if port_cursors is None:
            port_cursors = {}
        # Placeholder groups and train-only debug runs launch no engines.
        if self.args.debug_train_only or self.worker_type == "placeholder":
            self.num_new_engines = 0
            return [], port_cursors

        # An engine never uses more GPUs on one node than the node has.
        num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node)

        pg, reordered_bundle_indices, reordered_gpu_ids = self.pg

        RolloutRayActor = ray.remote(SGLangEngine)

        rollout_engines = []
        for i in range(len(self.all_engines)):
            # Only (re)start slots that are currently dead (None).
            if self.all_engines[i] is not None:
                continue

            global_rank = self.rank_offset + i
            # Fractional GPU reservation: the actor is a thin wrapper; the
            # SGLang server itself claims the GPUs via base_gpu_id.
            num_gpus = 0.2
            num_cpus = num_gpus

            gpu_index = self.gpu_offset + i * num_gpu_per_engine
            base_gpu_id = int(reordered_gpu_ids[gpu_index])

            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=reordered_bundle_indices[gpu_index],
            )

            # Keep CUDA_VISIBLE_DEVICES-style vars unset inside the actor and
            # forward SGLang tuning knobs, with defaults if absent in our env.
            env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} | {
                key: os.environ.get(key, default_val)
                for key, default_val in {
                    "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "false",
                    "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
                    "SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK": "true",
                    "SGLANG_MEMORY_SAVER_CUDA_GRAPH": "true",
                    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_FALLBACK_VARIANT": "true",
                    "SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION": "false",
                    "SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE": "false",
                }.items()
            }

            rollout_engine = RolloutRayActor.options(
                num_cpus=num_cpus,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                runtime_env={
                    "env_vars": env_vars,
                },
            ).remote(
                self.args,
                rank=global_rank,
                worker_type=self.worker_type,
                base_gpu_id=base_gpu_id,
                sglang_overrides=self.sglang_overrides,
                num_gpus_per_engine=self.num_gpus_per_engine,
            )

            rollout_engines.append((global_rank, rollout_engine))
            self.all_engines[i] = rollout_engine

        self.num_new_engines = len(rollout_engines)

        if self.num_new_engines == 0:
            return [], port_cursors

        # Continue port allocation after the highest cursor so successive
        # groups on the same node never race for the same ports.
        base_port = max(port_cursors.values()) if port_cursors else 15000
        addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal(
            args=self.args,
            rollout_engines=rollout_engines,
            worker_type=self.worker_type,
            num_gpus_per_engine=self.num_gpus_per_engine,
            rank_offset=self.rank_offset,
            base_port=base_port,
        )

        # Fire init() on every new engine without blocking; callers gather
        # the returned ObjectRefs.
        init_handles = [
            engine.init.remote(
                **(addr_and_ports[rank]),
                router_ip=self.router_ip,
                router_port=self.router_port,
            )
            for rank, engine in rollout_engines
        ]
        return init_handles, port_cursors

    def recover(self):
        """Recover dead engines across all active groups, overlapping init."""
        # NOTE(review): this method reads `self.server_groups` and
        # `self.update_weights`, which `ServerGroup` never defines — it looks
        # like it belongs on the aggregating `RolloutServer` object (which is
        # constructed in `start_rollout_servers` below). Confirm and relocate.
        # Snapshot dead slots BEFORE start_engines() fills them back in.
        dead_per_group = [[i for i, engine in enumerate(g.all_engines) if engine is None] for g in self.server_groups]

        all_handles = []
        port_cursors: dict[int, int] = {}
        for g in self.server_groups:
            handles, port_cursors = g.start_engines(port_cursors)
            all_handles.extend(handles)
        if all_handles:
            ray.get(all_handles)

        release_handles = []
        updatable_new_engines = []
        non_updatable_groups_engines: list[tuple[str, list]] = []
        for g, dead_indices in zip(self.server_groups, dead_per_group, strict=True):
            logger.info(f"Recovered {g.num_new_engines} dead rollout engines (worker_type={g.worker_type})")
            assert g.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length"
            if g.needs_offload and dead_indices:
                new_engines = [g.all_engines[i] for i in dead_indices]
                # Freshly started engines must first release memory so the
                # resume path below restores only what is needed.
                release_handles.extend(engine.release_memory_occupation.remote() for engine in new_engines)
                if self.update_weights:
                    updatable_new_engines.extend(new_engines)
                elif g.model_path:
                    non_updatable_groups_engines.append((g.model_path, new_engines))

        if release_handles:
            ray.get(release_handles)
        all_resume_engines = updatable_new_engines[:]
        for _model_path, engines in non_updatable_groups_engines:
            all_resume_engines.extend(engines)
        if all_resume_engines:
            # Bring back weights only; KV cache / CUDA graphs are resumed
            # separately via onload_kv().
            ray.get(
                [
                    engine.resume_memory_occupation.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])
                    for engine in all_resume_engines
                ]
            )

    def offload(self):
        # Returns un-awaited ObjectRefs; caller decides when to ray.get().
        if not self.needs_offload:
            return []
        return [engine.release_memory_occupation.remote() for engine in self.engines if engine is not None]

    def onload(self, tags: list[str] | None = None):
        # Returns un-awaited ObjectRefs; caller decides when to ray.get().
        if not self.needs_offload:
            return []
        return [engine.resume_memory_occupation.remote(tags=tags) for engine in self.engines if engine is not None]

    def onload_weights(self):
        # NOTE(review): unlike onload_kv(), this early-returns None when
        # needs_offload is False instead of []; callers ignoring the result
        # are unaffected, but the asymmetry looks unintentional.
        if not self.needs_offload:
            return
        handles = self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS])
        return ray.get(handles) if handles else []

    def onload_kv(self):
        # onload() already guards on needs_offload, so no early return here.
        handles = self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH])
        return ray.get(handles) if handles else []

    def onload_weights_from_disk(self):
        """Reload weights from ``model_path`` for non-updatable groups."""
        if not self.needs_offload or not self.model_path:
            return []
        return [
            engine.update_weights_from_disk.remote(self.model_path) for engine in self.engines if engine is not None
        ]

# ---------------------------------------------------------------------------
# RolloutManager
# ---------------------------------------------------------------------------

@ray.remote
class RolloutManager:
    """The class to run rollout and convert rollout data to training data."""

    def __init__(self, args, pg):
        configure_logger()

        self.pg = pg
        self.args = args
        # TODO make args immutable
        init_tracking(args, primary=False, router_addr=f"http://{args.sglang_router_ip}:{args.sglang_router_port}")

        # Resolve rollout/eval entry points from dotted-path config.
        self.generate_rollout = load_function(self.args.rollout_function_path)
        self.eval_generate_rollout = load_function(self.args.eval_function_path)
        self.custom_reward_post_process_func = None
        if self.args.custom_reward_post_process_path is not None:
            self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path)

        self.custom_convert_samples_to_train_data_func = None
        if self.args.custom_convert_samples_to_train_data_path is not None:
            self.custom_convert_samples_to_train_data_func = load_function(
                self.args.custom_convert_samples_to_train_data_path
            )
        logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.")
        logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.")

        if self.args.debug_train_only:
            # No inference engines are needed when only training is debugged.
            self.server_group = None
        else:
            init_http_client(args)
            self.server_group = start_rollout_servers(args, pg)

        # NOTE(review): attributes read elsewhere in this class but never set
        # here — `use_experimental_refactor`, `data_source`,
        # `health_monitoring_resume`, `train_parallel_config` (the latter via
        # set_train_parallel_config). Presumably initialized elsewhere; verify.
        self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote()
        self.rollout_id = -1

        self._metric_checker = MetricChecker.maybe_create(args)

    def dispose(self):
        if self._metric_checker is not None:
            self._metric_checker.dispose()

    @property
    def rollout_engines(self):
        """All node-0 engines across all servers / models."""
        return [e for e in self.server_group.engines]

    def get_updatable_engines_and_lock(self):
        """Return engines eligible for weight updates."""
        server_group = self.server_group
        engines = server_group.engines if server_group else []
        gpu_counts = server_group.engine_gpu_counts if server_group else []
        gpu_offsets = server_group.engine_gpu_offsets if server_group else []
        num_new = server_group.num_new_engines if server_group else 0
        return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets

    def generate(self, rollout_id):
        # One full rollout step: generate -> log -> convert -> split by DP rank.
        start_time = time.time()
        self.rollout_id = rollout_id
        self.health_monitoring_resume()
        data, metrics = self._get_rollout_data(rollout_id=rollout_id)
        self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False)
        _log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time)
        data = self._convert_samples_to_train_data(data)
        return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"])

    def eval(self, rollout_id):
        if self.args.debug_train_only:
            # if debug train only, we don't generate evaluation data
            return
        self.health_monitoring_resume()

        if self.use_experimental_refactor:
            result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id))
        else:
            result = call_rollout_fn(
                self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True
            )
        data = result.data
        self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True)
        metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics)
        if self._metric_checker is not None:
            self._metric_checker.on_eval(metrics)

    def offload(self, tags: list[str] | None = None):
        if tags is not None:
            handles = [
                engine.release_memory_occupation.remote(tags=tags)
                for engine in self.rollout_engines
                if engine is not None
            ]
            return ray.get(handles) if handles else []
        # NOTE(review): the tags=None path neither returns nor ray.get()s the
        # handles from server_group.offload(), so the release is fire-and-
        # forget and the method implicitly returns None — confirm intended.
        self.server_group.offload()


    def onload(self, tags: list[str] | None = None):
        # NOTE(review): like offload(), drops the returned handles.
        self.server_group.onload(tags)

    def onload_weights(self):
        self.server_group.onload_weights()

    def onload_kv(self):
        self.server_group.onload_kv()

    def recover_updatable_engines(self):
        """Restart any dead rollout engines and update num_new_engines for update_weights detection.

        Recovers the updatable model (the one that receives weight
        updates from training).
        """
        server_group = self.server_group
        # Before the first rollout (rollout_id == -1) or without engines,
        # report current state without attempting recovery.
        if self.rollout_id == -1 or server_group is None:
            engines = server_group.engines if server_group else []
            gpu_counts = server_group.engine_gpu_counts if server_group else []
            gpu_offsets = server_group.engine_gpu_offsets if server_group else []
            return engines, self.rollout_engine_lock, (server_group.num_new_engines if server_group else 0), gpu_counts, gpu_offsets

        server_group.recover()
        return (
            server_group.engines,
            self.rollout_engine_lock,
            server_group.num_new_engines,
            server_group.engine_gpu_counts,
            server_group.engine_gpu_offsets,
        )

    def clear_updatable_num_new_engines(self):
        # when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights
        if self.server_group:
            self.server_group.num_new_engines = 0

    def check_weights(self, action: str):
        return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines])

    def _get_rollout_data(self, rollout_id):
        # Either replay a saved rollout (debug) or run the rollout function.
        if self.args.load_debug_rollout_data:
            data = torch.load(
                self.args.load_debug_rollout_data.format(rollout_id=rollout_id),
                weights_only=False,
            )["samples"]
            data = [Sample.from_dict(sample) for sample in data]
            if (ratio := self.args.load_debug_rollout_data_subsample) is not None:
                original_num_rows = len(data)
                rough_subsample_num_rows = int(original_num_rows * ratio)
                # Keep half the subsample from the head and half from the tail.
                data = data[: rough_subsample_num_rows // 2] + data[-rough_subsample_num_rows // 2 :]
                logger.info(
                    f"Subsample loaded debug rollout data using {ratio=} and change num rows {original_num_rows} -> {len(data)}"
                )
            metrics = None
        else:
            if self.use_experimental_refactor:
                data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id))
            else:
                data = call_rollout_fn(
                    self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False
                )
            metrics = data.metrics
            data = data.samples
            # flatten the data if it is a list of lists
            while isinstance(data[0], list):
                data = list(itertools.chain.from_iterable(data))

        if not self.args.disable_rollout_trim_samples:
            global_batch_size = self.args.global_batch_size
            if self.args.use_dynamic_global_batch_size:
                logger.info(f"Collected {len(data)} samples from rollout to train with dynamic global batch size")
                # TODO: this is a temporary solution, we should directly save dynamic_global_batch_size to rollout data
                self._dynamic_global_batch_size = self._compute_dynamic_global_batch_size(len(data))
                global_batch_size = self._dynamic_global_batch_size

            # Trim to a whole multiple of global_batch_size.
            if len(data) % global_batch_size != 0:
                trim_len = (len(data) // global_batch_size) * global_batch_size
                if trim_len == 0:
                    raise ValueError(f"Not enough samples {len(data)} for global_batch_size {global_batch_size}")
                origin_data_length = len(data)
                data = data[:trim_len]
                logger.info(f"trim number of samples from {origin_data_length} to {trim_len}")
        logger.info(f"Final collected {len(data)} samples from rollout to train")

        return data, metrics

    def _compute_dynamic_global_batch_size(self, num_samples: int) -> int:
        """Calculate dynamic global_batch_size to ensure only one training step.

        Strategy: global_batch_size = num_samples rounded down to a multiple of dp_size
        This ensures num_steps_per_rollout = num_samples // global_batch_size = 1
        """
        dp_size = self.train_parallel_config["dp_size"]
        original_gbs = self.args.global_batch_size

        # Round down to a multiple of dp_size to ensure only one training step
        dynamic_gbs = (num_samples // dp_size) * dp_size

        if dynamic_gbs == 0:
            # Too few samples, use at least dp_size
            dynamic_gbs = dp_size
            logger.warning(f"num_samples={num_samples} < dp_size={dp_size}, using dp_size as global_batch_size")

        # Calculate how many samples will be discarded
        wasted = num_samples - dynamic_gbs

        if dynamic_gbs != original_gbs or wasted > 0:
            logger.info(
                f"Dynamic global_batch_size: {original_gbs} -> {dynamic_gbs} "
                f"(num_samples={num_samples}, dp_size={dp_size}, "
                f"num_steps=1, wasted={wasted})"
            )

        return dynamic_gbs

    def _save_debug_rollout_data(self, data, rollout_id, evaluation: bool):
        # TODO to be refactored (originally Buffer._set_data)
        if (path_template := self.args.save_debug_rollout_data) is not None:
            path = Path(path_template.format(rollout_id=("eval_" if evaluation else "") + str(rollout_id)))
            logger.info(f"Save debug rollout data to {path}")
            path.parent.mkdir(parents=True, exist_ok=True)

            # TODO may improve the format
            if evaluation:
                # Eval data is a dict of dataset_name -> info with "samples".
                dump_data = dict(
                    samples=[sample.to_dict() for dataset_name, info in data.items() for sample in info["samples"]]
                )
            else:
                dump_data = dict(
                    samples=[sample.to_dict() for sample in data],
                )

            torch.save(dict(rollout_id=rollout_id, **dump_data), path)

    def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
        """Return (raw_rewards, processed_rewards) for *samples*.

        Applies GRPO-style per-group reward normalization when configured.
        """
        if self.custom_reward_post_process_func is not None:
            return self.custom_reward_post_process_func(self.args, samples)

        raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
        if (
            self.args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"]
            and self.args.rewards_normalization
        ):
            # group norm
            rewards = torch.tensor(raw_rewards, dtype=torch.float)
            if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size:
                rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
            else:
                # when samples count are not equal in each group
                rewards = rewards.view(-1, rewards.shape[-1])
            mean = rewards.mean(dim=-1, keepdim=True)
            rewards = rewards - mean

            if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization:
                # Epsilon guards against zero-variance groups.
                std = rewards.std(dim=-1, keepdim=True)
                rewards = rewards / (std + 1e-6)

            return raw_rewards, rewards.flatten().tolist()

        return raw_rewards, raw_rewards

    def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
        """
        Convert inference generated samples to training data.
        """
        if self.custom_convert_samples_to_train_data_func is not None:
            return self.custom_convert_samples_to_train_data_func(self.args, samples)

        raw_rewards, rewards = self._post_process_rewards(samples)

        assert len(raw_rewards) == len(samples)
        assert len(rewards) == len(samples)

        train_data = {
            "tokens": [sample.tokens for sample in samples],
            "response_lengths": [sample.response_length for sample in samples],
            # some reward model, e.g. remote rm, may return multiple rewards,
            # we could use key to select the reward.
            "rewards": rewards,
            "raw_reward": raw_rewards,
            "truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
            "sample_indices": [sample.index for sample in samples],
        }

        # loss mask
        # TODO: compress the loss mask
        loss_masks = []
        for sample in samples:
            # always instantiate loss_mask if not provided
            if sample.loss_mask is None:
                sample.loss_mask = [1] * sample.response_length

            assert (
                len(sample.loss_mask) == sample.response_length
            ), f"loss mask length {len(sample.loss_mask)} != response length {sample.response_length}"
            if sample.remove_sample:
                # Zero mask excludes the sample from the loss entirely.
                sample.loss_mask = [0] * sample.response_length
            loss_masks.append(sample.loss_mask)
        train_data["loss_masks"] = loss_masks

        # overwriting the raw reward
        if samples[0].metadata and "raw_reward" in samples[0].metadata:
            train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples]

        # For rollout buffer
        if samples[0].metadata and "round_number" in samples[0].metadata:
            train_data["round_number"] = [sample.metadata["round_number"] for sample in samples]

        # Add rollout log probabilities for off-policy correction
        if samples[0].rollout_log_probs is not None:
            train_data["rollout_log_probs"] = [sample.rollout_log_probs for sample in samples]

        if samples[0].rollout_routed_experts is not None:
            train_data["rollout_routed_experts"] = [sample.rollout_routed_experts for sample in samples]

        if samples[0].train_metadata is not None:
            train_data["metadata"] = [sample.train_metadata for sample in samples]

        if any(sample.multimodal_train_inputs is not None for sample in samples):
            train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples]

        if "teacher_log_probs" in samples[0].__dict__:
            train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]

        return train_data

    def set_train_parallel_config(self, config: dict):
        self.train_parallel_config = config

    def _split_train_data_by_dp(self, data, dp_size):
        """Split the train data by data parallel size."""
        rollout_data = {}

        # NOTE(review): this pre-loop `rollout_data` (including the "prompt"
        # copy below) is discarded — the loop rebinds `rollout_data` to a new
        # dict per DP rank. Looks like dead code; confirm and remove.
        if "prompt" in data:
            rollout_data["prompt"] = data["prompt"]

        total_lengths = [len(t) for t in data["tokens"]]
        data["total_lengths"] = total_lengths

        if self.args.balance_data:
            # Balance per-rank token counts rather than sample counts.
            partitions = get_seqlen_balanced_partitions(total_lengths, dp_size, equal_size=True)
        else:
            # Simple strided round-robin assignment.
            partitions = [range(i, len(total_lengths), dp_size) for i in range(dp_size)]

        rollout_data_refs = []

        for i in range(dp_size):
            rollout_data = {}
            partition = partitions[i]
            rollout_data["partition"] = partition
            for key in [
                "tokens",
                "multimodal_train_inputs",
                "response_lengths",
                "rewards",
                "truncated",
                "loss_masks",
                "round_number",
                "sample_indices",
                "rollout_log_probs",
                "rollout_routed_experts",
                "prompt",
                "teacher_log_probs",
            ]:
                if key not in data:
                    continue
                val = [data[key][j] for j in partition]
                rollout_data[key] = val
            # keys that need to be splited at train side
            for key in [
                "raw_reward",
                "total_lengths",
            ]:
                if key not in data:
                    continue
                rollout_data[key] = data[key]
            # Pass dynamic global_batch_size to training side
            if hasattr(self, "_dynamic_global_batch_size"):
                rollout_data["dynamic_global_batch_size"] = self._dynamic_global_batch_size
            rollout_data_refs.append(Box(ray.put(rollout_data)))
        return rollout_data_refs


# ---------------------------------------------------------------------------
# Port allocation helpers
# ---------------------------------------------------------------------------

def _allocate_rollout_engine_addr_and_ports_normal(
    *,
    args,
    rollout_engines,
    worker_type="regular",
    num_gpus_per_engine=None,
    rank_offset=0,
    base_port=15000,
):
    # get ports
    # there are 4 ports we need to allocate
    # 1. server port
    # 2. nccl port
    # 3. dist_init_addr port
    # 4. other ports for dp_attention, which is of size 4 + dp_size
    _gpus_per_engine = num_gpus_per_engine or args.rollout_num_gpus_per_engine
    num_engines_per_node = max(1, args.num_gpus_per_node // _gpus_per_engine)
    addr_and_ports: dict[int, dict] = {}

    # Track per-node port cursors so that different server groups (called
    # sequentially) never race for the same ports on a given node.
    node_port_cursor: dict[int, int] = {}

    visited_nodes = set()
    for rank, engine in rollout_engines:
        local_rank = rank - rank_offset
        node_index = local_rank // num_engines_per_node
        if node_index in visited_nodes:
            continue
        visited_nodes.add(node_index)
        # TODO: currently when restarting engines, we will set port for all engines on this node starting with this rank.
        # e.g. for 8 gpus, if we are restarting engine on gpu 3, we will set port for engine 3,4,5,6,7 on this node.
        num_engines_on_this_node = num_engines_per_node - (local_rank % num_engines_per_node)

        def get_addr_and_ports(engine, node_idx):
            # use small ports to prevent ephemeral port between 32768 and 65536.
            # also, ray uses port 10002-19999, thus we avoid near-10002 to avoid racing condition
            start_port = node_port_cursor.get(node_idx, base_port)

            def port(consecutive=1):
                # Ask the engine's node for `consecutive` free ports starting
                # at the cursor, then advance the cursor past them.
                nonlocal start_port
                _, port = ray.get(
                    engine._get_current_node_ip_and_free_port.remote(
                        start_port=start_port,
                        consecutive=consecutive,
                    )
                )
                start_port = port + consecutive
                node_port_cursor[node_idx] = start_port
                return port

            def addr():
                addr, _ = ray.get(engine._get_current_node_ip_and_free_port.remote())
                return addr

            return addr, port

        get_addr, get_port = get_addr_and_ports(engine, node_index)

        for i in range(num_engines_on_this_node):
            current_rank = rank + i
            addr_and_ports.setdefault(current_rank, {})
            addr_and_ports[current_rank]["host"] = get_addr()
            addr_and_ports[current_rank]["port"] = get_port()
            addr_and_ports[current_rank]["nccl_port"] = get_port()

            if worker_type == "prefill":
                # PD disaggregation: prefill engines need a bootstrap port.
                addr_and_ports[current_rank]["disaggregation_bootstrap_port"] = get_port()

        if _gpus_per_engine > args.num_gpus_per_node:
            # Multi-node engine: only node 0 of each engine allocates the
            # dist_init_addr; all member nodes share it.
            num_node_per_engine = _gpus_per_engine // args.num_gpus_per_node
            if local_rank % num_node_per_engine == 0:
                dist_init_addr = f"{get_addr()}:{get_port(30 + args.sglang_dp_size)}"
                for i in range(num_node_per_engine):
                    addr_and_ports.setdefault(rank + i, {})
                    addr_and_ports[rank + i]["dist_init_addr"] = dist_init_addr
        else:
            for i in range(num_engines_on_this_node):
                addr_and_ports[rank + i]["dist_init_addr"] = f"{get_addr()}:{get_port(30 + args.sglang_dp_size)}"

    for i, _ in rollout_engines:
        for key in ["port", "nccl_port", "dist_init_addr"]:
            assert key in addr_and_ports[i], f"Engine {i} {key} is not set."
        logger.info(f"Ports for engine {i}: {addr_and_ports[i]}")

    return addr_and_ports, node_port_cursor

# ---------------------------------------------------------------------------
# Router + server bootstrap
# ---------------------------------------------------------------------------

def _start_router(args, *, force_new: bool = False) -> tuple[str, int]:
    """Start sgl router or miles router and return (router_ip, router_port).

    If ``args.sglang_router_ip`` is already set and ``force_new`` is False,
    skip launching and return the existing values.
    """
    if not force_new and args.sglang_router_ip is not None:
        return args.sglang_router_ip, args.sglang_router_port

    router_ip = _wrap_ipv6(get_host_info()[1])
    if force_new:
        router_port = find_available_port(random.randint(3000, 4000))
    else:
        router_port = args.sglang_router_port
        if router_port is None:
            router_port = find_available_port(random.randint(3000, 4000))

    from sglang_router.launch_router import RouterArgs

    from miles.utils.http_utils import run_router

    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
    router_args.host = router_ip
    router_args.port = router_port
    router_args.prometheus_port = find_available_port(random.randint(4000, 5000))
    router_args.log_level = "warn"
    router_args.request_timeout_secs = args.sglang_router_request_timeout_secs

    logger.info(f"Launch router with args: {router_args}")

    # Daemon process so the router dies with the driver.
    process = multiprocessing.Process(
        target=run_router,
        args=(router_args,),
    )
    process.daemon = True
    process.start()
    # NOTE(review): fixed 3s sleep + is_alive() is a weak readiness check;
    # a health-endpoint poll would be more robust.
    time.sleep(3)
    assert process.is_alive()
    logger.info(f"Router launched at {router_ip}:{router_port}")
    return router_ip, router_port

def _compute_rollout_offset(args) -> int:
    """Offset (in PG bundle slots) where rollout GPUs start."""
    if args.debug_train_only or args.debug_rollout_only or args.colocate:
        return 0
    if getattr(args, "critic_train_only", False):
        return args.critic_num_nodes * args.critic_num_gpus_per_node
    offset = args.actor_num_nodes * args.actor_num_gpus_per_node
    if getattr(args, "use_critic", False):
        offset += args.critic_num_nodes * args.critic_num_gpus_per_node
    return offset


def _compute_megatron_num_gpus(args) -> int:
    """Total number of megatron (actor + critic) GPU slots in the placement group."""
    if getattr(args, "debug_rollout_only", False):
        return 0
    if getattr(args, "critic_train_only", False):
        return args.critic_num_nodes * args.critic_num_gpus_per_node
    num = args.actor_num_nodes * args.actor_num_gpus_per_node
    if getattr(args, "use_critic", False):
        num += args.critic_num_nodes * args.critic_num_gpus_per_node
    return num


def start_rollout_servers(args, pg) -> ServerGroup:
    """Start rollout servers: one per model, each with its own router.

    Returns a dict mapping model name -> ``RolloutServer``.
    """
    # NOTE(review): this function is inconsistent as written — the return
    # annotation says ServerGroup but the docstring and code return a dict;
    # `model_cfg`, `servers`, and `RolloutServer` are used but never defined
    # in this file; `server_group`/`gpu_offset = 0` below are partially dead.
    # It appears to be mid-refactor between a single-group and multi-model
    # design — must be reconciled before this module can run.
    config = _resolve_sglang_config(args)

    server_group: ServerGroup = None
    gpu_offset = 0
    engine_offset = 0

    rollout_pg_offset = _compute_rollout_offset(args)
    megatron_num_gpus = _compute_megatron_num_gpus(args)

    model_cfg.resolve(args)
    router_ip, router_port = _start_router(args, force_new=False)
    args.sglang_router_ip = router_ip
    args.sglang_router_port = router_port

    server_groups: list[ServerGroup] = []
    all_init_handles: list = []
    port_cursors: dict[int, int] = {}

    for group_cfg in model_cfg.server_groups:
        gpus_per_engine = group_cfg.num_gpus_per_engine
        num_gpu_per_engine_local = min(gpus_per_engine, args.num_gpus_per_node)
        num_engines = group_cfg.num_gpus // num_gpu_per_engine_local

        # Only groups whose GPUs overlap the megatron (training) slots need
        # to offload between rollout and training.
        group_abs_start = rollout_pg_offset + gpu_offset
        needs_offload = args.offload_rollout and group_abs_start < megatron_num_gpus
        overrides = dict(group_cfg.overrides)
        if args.offload_rollout and not needs_offload:
            overrides.setdefault("enable_memory_saver", False)
        logger.info(
            f"Engine group '{group_cfg.worker_type}' gpu_offset={gpu_offset} "
            f"(abs={group_abs_start}): needs_offload={needs_offload}"
        )

        group = ServerGroup(
            args=args,
            pg=pg,
            all_engines=[None] * num_engines if group_cfg.worker_type != "placeholder" else [],
            num_gpus_per_engine=gpus_per_engine,
            num_new_engines=0,
            worker_type=group_cfg.worker_type,
            rank_offset=engine_offset,
            gpu_offset=gpu_offset,
            sglang_overrides=overrides,
            needs_offload=needs_offload,
            model_path=overrides.get("model_path", args.hf_checkpoint),
            router_ip=router_ip,
            router_port=router_port,
        )
        handles, port_cursors = group.start_engines(port_cursors)
        all_init_handles.extend(handles)
        server_groups.append(group)

        engine_offset += num_engines
        gpu_offset += group_cfg.num_gpus

    if all_init_handles:
        ray.get(all_init_handles)

    servers[model_cfg.name] = RolloutServer(
        server_groups=server_groups,
        router_ip=router_ip,
        router_port=router_port,
        model_name=model_cfg.name,
        update_weights=model_cfg.update_weights,
    )

    args.sglang_model_routers = {name: (srv.router_ip, srv.router_port) for name, srv in servers.items()}

    return servers


def _resolve_sglang_config(args) -> SglangConfig:
    """Build a SglangConfig from args, choosing the right source."""
    # Priority: explicit YAML config > prefill/decode split > single
    # default regular group covering all rollout GPUs.
    if getattr(args, "sglang_config", None) is not None:
        config = SglangConfig.from_yaml(args.sglang_config)
        expected = args.rollout_num_gpus
        actual = config.total_num_gpus
        assert actual == expected, f"sglang_config total GPUs ({actual}) != rollout_num_gpus ({expected})"
        return config

    if args.prefill_num_servers is not None:
        return SglangConfig.from_prefill_num_servers(args)

    return SglangConfig(
        models=[
            ModelConfig(
                name="default",
                server_groups=[ServerGroupConfig(worker_type="regular", num_gpus=args.rollout_num_gpus)],
            )
        ]
    )


# ---------------------------------------------------------------------------
# Logging / metrics helpers (unchanged)
# ---------------------------------------------------------------------------


def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any] | None = None):
    """Aggregate and log eval-rollout metrics; returns the logged dict."""
    # A custom log function returning truthy fully replaces default logging.
    if args.custom_eval_rollout_log_function_path is not None:
        custom_log_func = load_function(args.custom_eval_rollout_log_function_path)
        if custom_log_func(rollout_id, args, data, extra_metrics):
            return

    log_dict = extra_metrics or {}
    for key in data.keys():
        rewards = data[key]["rewards"]
        log_dict[f"eval/{key}"] = sum(rewards) / len(rewards)
        if (samples := data[key].get("samples")) is not None:
            log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), f"eval/{key}/")
        if "truncated" in data[key]:
            truncated = data[key]["truncated"]
            log_dict[f"eval/{key}-truncated_ratio"] = sum(truncated) / len(truncated)
        if args.log_passrate:
            log_dict |= dict_add_prefix(
                compute_pass_rate(
                    flat_rewards=rewards,
                    group_size=args.n_samples_per_eval_prompt,
                ),
                f"eval/{key}-",
            )

    logger.info(f"eval {rollout_id}: {log_dict}")

    step = compute_rollout_step(args, rollout_id)
    log_dict["eval/step"] = step
    tracking_utils.log(args, log_dict, step_key="eval/step")

    return log_dict


def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time):
    """Log training-rollout quality and perf metrics for one rollout step."""
    if args.custom_rollout_log_function_path is not None:
        custom_log_func = load_function(args.custom_rollout_log_function_path)
        if custom_log_func(rollout_id, args, samples, rollout_extra_metrics, rollout_time):
            return

    # Replayed debug data has no fresh metrics worth logging.
    if args.load_debug_rollout_data:
        return

    log_dict = {**(rollout_extra_metrics or {})}
    log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/")
    log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/")
    logger.info(f"perf {rollout_id}: {log_dict}")
    step = compute_rollout_step(args, rollout_id)
    log_dict["rollout/step"] = step
    tracking_utils.log(args, log_dict, step_key="rollout/step")


def compute_metrics_from_samples(args, samples):
    """Compute quality metrics (lengths, repetition, truncation, …) over samples."""
    response_lengths = [sample.effective_response_length for sample in samples]

    log_dict = {}
    log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
    log_dict |= _compute_zero_std_metrics(args, samples)
    log_dict |= _compute_spec_metrics(args, samples)
    log_dict |= _compute_prefix_cache_metrics(args, samples)
    log_dict |= _compute_reward_cat_metrics(args, samples)
    log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item()
    log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item()
    return log_dict


def compute_perf_metrics_from_samples(args, samples, rollout_time):
    """Compute throughput metrics for one rollout step of duration *rollout_time*."""
    non_generation_time = [sample.non_generation_time for sample in samples]

    log_dict = {}
    log_dict["rollout_time"] = rollout_time
    if max(non_generation_time) > 0:
        log_dict |= dict_add_prefix(compute_statistics(non_generation_time), "non_generation_time/")

    def token_perf(response_lengths, non_generation_time, key=""):
        # Throughput overall and for the longest (straggler) sample.
        max_response_length = max(response_lengths)
        if args.rollout_num_gpus:
            log_dict[f"{key}tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus
        log_dict[f"longest_{key}sample_tokens_per_sec"] = max_response_length / rollout_time

        if max(non_generation_time) == 0:
            return

        # Average non-generation time over the straggler samples only.
        non_generation_time = [
            t for t, length in zip(non_generation_time, response_lengths, strict=True) if length == max_response_length
        ]
        mean_non_generation_time = sum(non_generation_time) / len(non_generation_time)

        log_dict[f"longest_{key}sample_non_generation_time"] = mean_non_generation_time
        log_dict[f"longest_{key}sample_tokens_per_sec_without_non_generation"] = max_response_length / (
            rollout_time - mean_non_generation_time
        )

    token_perf([sample.response_length for sample in samples], non_generation_time, key="")
    token_perf([sample.effective_response_length for sample in samples], non_generation_time, key="effective_")

    return log_dict


def _compute_zero_std_metrics(args, all_samples: list[Sample]):
    # only compute in GRPO-like algorithms where one prompt has multiple responses
    if args.advantage_estimator == "ppo":
        return {}

    def _is_zero_std(samples: list[Sample]):
        # A group whose rewards are all identical contributes zero advantage.
        rewards = [sample.get_reward_value(args) for sample in samples]
        return len(rewards) == 0 or all(rewards[0] == r for r in rewards)

    all_sample_groups = group_by(all_samples, lambda s: s.group_index)
    interesting_sample_groups = [g for g in all_sample_groups.values() if _is_zero_std(g)]

    interesting_rewards = [str(round(g[0].get_reward_value(args), 1)) for g in interesting_sample_groups]

    return {f"zero_std/count_{reward}": len(items) for reward, items in group_by(interesting_rewards).items()}


def _compute_spec_metrics(args, all_samples: list[Sample]):
    # Speculative-decoding acceptance stats; only when spec decoding is on.
    if args.sglang_speculative_algorithm is None:
        return {}
    num_samples = len(all_samples)
    metrics = {}
    metrics["spec_accept_rate"] = sum(sample.spec_info.spec_accept_rate for sample in all_samples) / num_samples
    metrics["spec_accept_length"] = sum(sample.spec_info.spec_accept_length for sample in all_samples) / num_samples
    return metrics


def _compute_prefix_cache_metrics(args, all_samples: list[Sample]):
    # Prefix (radix) cache hit statistics aggregated across all samples.
    num_samples = len(all_samples)
    metrics = {}
    total_cached_tokens = sum(sample.prefix_cache_info.cached_tokens for sample in all_samples)
    total_prompt_tokens = sum(sample.prefix_cache_info.total_prompt_tokens for sample in all_samples)

    metrics["prefix_cache_hit_rate"] = total_cached_tokens / total_prompt_tokens if total_prompt_tokens > 0 else 0.0
    metrics["avg_cached_tokens_per_sample"] = total_cached_tokens / num_samples
    return metrics


def _compute_reward_cat_metrics(args, all_samples: list[Sample]):
    # Fraction of samples per reward category (e.g. error class), if configured.
    reward_cat_key = args.log_reward_category
    if reward_cat_key is None:
        return {}

    samples_of_reward_cat = group_by(all_samples, lambda s: s.reward[reward_cat_key])

    return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()}
diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py
index 3ba526d37e..6087dcbb19 100644
--- a/nemo_rl/models/generation/redesign/sglang_worker.py
+++ b/nemo_rl/models/generation/redesign/sglang_worker.py
@@ -101,107 +101,6 @@ def _wait_server_healthy(base_url, api_key, is_process_alive):
         time.sleep(2)
 
 
-def find_available_port(base_port: int):
-    port = base_port + random.randint(100, 1000)
-    while True:
-        if is_port_available(port):
-            return port
-        if port < 60000:
-            port += 42
-        else:
-            port -= 43
-
-def is_port_available(port):
-    """Return whether a port is available."""
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        try:
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            s.bind(("", port))
-            s.listen(1)
-            return True
-        except OSError:
-            return False
-        except OverflowError:
-            return False
-
-def get_host_info():
-    hostname = socket.gethostname()
-
-    if env_overwrite_local_ip := os.getenv(MILES_HOST_IP_ENV, None):
-        return hostname, env_overwrite_local_ip
-
-    def _is_loopback(ip):
-        return ip.startswith("127.") or ip == "::1"
-
-    def _resolve_ip(family, test_target_ip):
-        """
-        Attempt to get the local LAN IP for the specific family (IPv4/IPv6).
-        Strategy: UDP Probe (Preferred) -> Hostname Resolution (Fallback) -> None
-        """
-
-        # Strategy 1: UDP Connect Probe (Most accurate, relies on routing table)
-        # Useful when the machine has a default gateway or internet access.
-        try:
-            with socket.socket(family, socket.SOCK_DGRAM) as s:
-                # The IP doesn't need to be reachable, but the routing table must exist.
-                s.connect((test_target_ip, 80))
-                ip = s.getsockname()[0]
-                if not _is_loopback(ip):
-                    return ip
-        except Exception:
-            pass  # Route unreachable or network error, move to next strategy.
- - # Strategy 2: Hostname Resolution (Fallback for offline clusters) - # Useful for offline environments where UDP connect fails but /etc/hosts is configured. - try: - # getaddrinfo allows specifying the family (AF_INET or AF_INET6) - # Result format: [(family, type, proto, canonname, sockaddr), ...] - infos = socket.getaddrinfo(hostname, None, family=family, type=socket.SOCK_STREAM) - - for info in infos: - ip = info[4][0] # The first element of sockaddr is the IP - # Must filter out loopback addresses to avoid "127.0.0.1" issues - if not _is_loopback(ip): - return ip - except Exception: - pass - - return None - - prefer_ipv6 = os.getenv("MILES_PREFER_IPV6", "0").lower() in ("1", "true", "yes", "on") - local_ip = None - final_fallback = "127.0.0.1" - - if prefer_ipv6: - # [Strict Mode] IPv6 Only - # 1. Try UDP V6 Probe - # 2. Try Hostname Resolution (V6) - # If failed, fallback to V6 loopback. Never mix with V4. - local_ip = _resolve_ip(socket.AF_INET6, "2001:4860:4860::8888") - final_fallback = "::1" - else: - # [Strict Mode] IPv4 Only (Default) - # 1. Try UDP V4 Probe - # 2. Try Hostname Resolution (V4) - # If failed, fallback to V4 loopback. Never mix with V6. - local_ip = _resolve_ip(socket.AF_INET, "8.8.8.8") - final_fallback = "127.0.0.1" - - return hostname, local_ip or final_fallback - -def get_current_node_ip(): - address = ray._private.services.get_node_ip_address() - # strip ipv6 address - address = address.strip("[]") - return address - -def get_free_port(start_port=10000, consecutive=1): - # find the port where port, port + 1, port + 2, ... 
port + consecutive - 1 are all available - port = start_port - while not all(is_port_available(port + i) for i in range(consecutive)): - port += 1 - return port - class SGLangEngine: def __init__( self, @@ -225,7 +124,6 @@ def init( port, nccl_port, host=None, - disaggregation_bootstrap_port=None, router_ip=None, router_port=None, ): @@ -262,8 +160,6 @@ def _format_v6_uri(addr): nccl_port, host, port, - self.worker_type, - disaggregation_bootstrap_port, base_gpu_id=self.base_gpu_id, sglang_overrides=self.sglang_overrides, num_gpus_per_engine=self.num_gpus_per_engine, @@ -585,8 +481,6 @@ def _compute_server_args( nccl_port, host, port, - worker_type: str = "regular", - disaggregation_bootstrap_port: int | None = None, base_gpu_id: int | None = None, sglang_overrides: dict | None = None, num_gpus_per_engine: int | None = None, diff --git a/nemo_rl/models/generation/redesign/utils.py b/nemo_rl/models/generation/redesign/utils.py new file mode 100644 index 0000000000..a7df1ecc78 --- /dev/null +++ b/nemo_rl/models/generation/redesign/utils.py @@ -0,0 +1,107 @@ + +@ray.remote(num_gpus=1) +class InfoActor: + def get_ip_and_gpu_id(self): + return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0] + + +def find_available_port(base_port: int): + port = base_port + random.randint(100, 1000) + while True: + if is_port_available(port): + return port + if port < 60000: + port += 42 + else: + port -= 43 + +def is_port_available(port): + """Return whether a port is available.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return True + except OSError: + return False + except OverflowError: + return False + +def get_host_info(): + hostname = socket.gethostname() + + if env_overwrite_local_ip := os.getenv(MILES_HOST_IP_ENV, None): + return hostname, env_overwrite_local_ip + + def _is_loopback(ip): + return ip.startswith("127.") or ip == "::1" + + def 
_resolve_ip(family, test_target_ip): + """ + Attempt to get the local LAN IP for the specific family (IPv4/IPv6). + Strategy: UDP Probe (Preferred) -> Hostname Resolution (Fallback) -> None + """ + + # Strategy 1: UDP Connect Probe (Most accurate, relies on routing table) + # Useful when the machine has a default gateway or internet access. + try: + with socket.socket(family, socket.SOCK_DGRAM) as s: + # The IP doesn't need to be reachable, but the routing table must exist. + s.connect((test_target_ip, 80)) + ip = s.getsockname()[0] + if not _is_loopback(ip): + return ip + except Exception: + pass # Route unreachable or network error, move to next strategy. + + # Strategy 2: Hostname Resolution (Fallback for offline clusters) + # Useful for offline environments where UDP connect fails but /etc/hosts is configured. + try: + # getaddrinfo allows specifying the family (AF_INET or AF_INET6) + # Result format: [(family, type, proto, canonname, sockaddr), ...] + infos = socket.getaddrinfo(hostname, None, family=family, type=socket.SOCK_STREAM) + + for info in infos: + ip = info[4][0] # The first element of sockaddr is the IP + # Must filter out loopback addresses to avoid "127.0.0.1" issues + if not _is_loopback(ip): + return ip + except Exception: + pass + + return None + + prefer_ipv6 = os.getenv("MILES_PREFER_IPV6", "0").lower() in ("1", "true", "yes", "on") + local_ip = None + final_fallback = "127.0.0.1" + + if prefer_ipv6: + # [Strict Mode] IPv6 Only + # 1. Try UDP V6 Probe + # 2. Try Hostname Resolution (V6) + # If failed, fallback to V6 loopback. Never mix with V4. + local_ip = _resolve_ip(socket.AF_INET6, "2001:4860:4860::8888") + final_fallback = "::1" + else: + # [Strict Mode] IPv4 Only (Default) + # 1. Try UDP V4 Probe + # 2. Try Hostname Resolution (V4) + # If failed, fallback to V4 loopback. Never mix with V6. 
+ local_ip = _resolve_ip(socket.AF_INET, "8.8.8.8") + final_fallback = "127.0.0.1" + + return hostname, local_ip or final_fallback + +def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + address = address.strip("[]") + return address + +def get_free_port(start_port=10000, consecutive=1): + # find the port where port, port + 1, port + 2, ... port + consecutive - 1 are all available + port = start_port + while not all(is_port_available(port + i) for i in range(consecutive)): + port += 1 + return port \ No newline at end of file From decb22e6525a5a4b2f27e4ed60b3ca3e1acb2106 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 5 Apr 2026 14:09:20 -0700 Subject: [PATCH 03/36] update --- nemo_rl/models/generation/redesign/config.py | 197 ++++++++++++++++++ .../generation/redesign/sglang_generation.py | 33 ++- .../generation/redesign/sglang_worker.py | 2 +- 3 files changed, 224 insertions(+), 8 deletions(-) create mode 100644 nemo_rl/models/generation/redesign/config.py diff --git a/nemo_rl/models/generation/redesign/config.py b/nemo_rl/models/generation/redesign/config.py new file mode 100644 index 0000000000..68168c6087 --- /dev/null +++ b/nemo_rl/models/generation/redesign/config.py @@ -0,0 +1,197 @@ +"""Configuration dataclasses for SGLang engine deployment.""" + +import dataclasses +import logging + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ServerGroupConfig: + """Configuration for a single server group. + + Attributes: + worker_type: One of "regular", "prefill", "decode", or "placeholder". + "placeholder" reserves GPU slots without creating engines. + num_gpus: Total number of GPUs for this group. + num_gpus_per_engine: GPUs per engine for this group. Overrides the + model-level or global ``--rollout-num-gpus-per-engine``. + overrides: Optional dict of SGLang ``ServerArgs`` field overrides. 
+ These are applied on top of the base CLI ``--sglang-*`` + arguments in ``_compute_server_args``. + """ + + worker_type: str + num_gpus: int + num_gpus_per_engine: int | None = None + overrides: dict = dataclasses.field(default_factory=dict) + + def __post_init__(self): + valid_types = {"regular", "prefill", "decode", "placeholder"} + assert ( + self.worker_type in valid_types + ), f"Invalid worker_type '{self.worker_type}', must be one of {valid_types}" + assert self.num_gpus > 0, f"num_gpus must be > 0, got {self.num_gpus}" + + +@dataclasses.dataclass +class ModelConfig: + """Configuration for a single model deployment. + + Attributes: + name: Unique name for this model (e.g. "actor", "reward"). + model_path: HF checkpoint path. Falls back to ``args.hf_checkpoint``. + num_gpus_per_engine: Default GPUs per engine for all groups in this + model. Individual groups can override. + server_groups: Server group configurations for this model. + update_weights: Whether this model receives weight updates from + training. Set to ``False`` for frozen models + (reference, reward, etc.). When ``None`` (default), + automatically inferred in ``resolve()``: ``True`` if + model_path matches ``args.hf_checkpoint``, ``False`` + otherwise. 
+ """ + + name: str + model_path: str | None = None + num_gpus_per_engine: int | None = None + server_groups: list[ServerGroupConfig] = dataclasses.field(default_factory=list) + update_weights: bool | None = None + + def resolve(self, args) -> None: + """Resolve per-group defaults from model-level then args-level values.""" + default_gpus_per_engine = self.num_gpus_per_engine or args.rollout_num_gpus_per_engine + default_model_path = self.model_path or args.hf_checkpoint + for g in self.server_groups: + if g.num_gpus_per_engine is None: + g.num_gpus_per_engine = default_gpus_per_engine + if "model_path" not in g.overrides: + g.overrides["model_path"] = default_model_path + + if self.server_groups: + model_paths = {g.overrides["model_path"] for g in self.server_groups} + assert len(model_paths) == 1, ( + f"Model '{self.name}' has server groups with different model_path values: " + f"{model_paths}. All server groups within a model must use the same model_path." + ) + effective_model_path = model_paths.pop() + else: + effective_model_path = default_model_path + + if self.update_weights is None: + if effective_model_path != args.hf_checkpoint: + logger.warning( + f"Model '{self.name}' uses model_path='{effective_model_path}' which differs " + f"from hf_checkpoint='{args.hf_checkpoint}'. Defaulting update_weights to False. " + f"Set update_weights explicitly in the config to suppress this warning." + ) + self.update_weights = False + else: + self.update_weights = True + + @property + def has_pd_disaggregation(self) -> bool: + return any(g.worker_type in ("prefill", "decode") for g in self.server_groups) + + @property + def total_num_gpus(self) -> int: + return sum(g.num_gpus for g in self.server_groups) + + +@dataclasses.dataclass +class SglangConfig: + """Configuration for SGLang engine deployment. + + Loaded from ``--sglang-config`` YAML file. 
+ + **Config format**:: + + sglang: + - name: actor + model_path: /path/to/actor + update_weights: true # receives training weight updates (default) + num_gpus_per_engine: 2 + server_groups: + - worker_type: prefill + num_gpus: 4 + num_gpus_per_engine: 2 + - worker_type: decode + num_gpus: 8 + num_gpus_per_engine: 4 + - name: ref + model_path: /path/to/ref + update_weights: false # frozen, no weight updates + server_groups: + - worker_type: regular + num_gpus: 4 + + Each model gets its own router. ``placeholder`` groups reserve GPU + slots without creating engines. ``overrides`` are ``ServerArgs`` + field names applied on top of the base ``--sglang-*`` CLI args. + + Set ``update_weights: false`` for frozen models (reference, reward, + etc.) that should not receive weight updates from training. + + .. note:: + + ``engine_groups`` is accepted as a backward-compatible alias for + ``server_groups`` in the YAML config. + """ + + models: list[ModelConfig] + + @staticmethod + def from_yaml(path: str) -> "SglangConfig": + with open(path) as f: + data = yaml.safe_load(f) + + assert "sglang" in data, ( + f"sglang config must have a 'sglang' key, got {list(data.keys())}. " + f"Wrap your server_groups inside a model entry under 'sglang'." 
+ ) + models = [] + for m in data["sglang"]: + raw_groups = m.get("server_groups") or m.get("engine_groups") or [] + groups = [ServerGroupConfig(**g) for g in raw_groups] + models.append( + ModelConfig( + name=m["name"], + model_path=m.get("model_path"), + num_gpus_per_engine=m.get("num_gpus_per_engine"), + server_groups=groups, + update_weights=m.get("update_weights"), + ) + ) + return SglangConfig(models=models) + + @staticmethod + def from_prefill_num_servers(args) -> "SglangConfig": + """Build a config equivalent to the legacy --prefill-num-servers flag.""" + total_gpus = args.rollout_num_gpus + prefill_gpus = args.prefill_num_servers * args.rollout_num_gpus_per_engine + decode_gpus = total_gpus - prefill_gpus + assert decode_gpus > 0, f"No decode GPUs: total {total_gpus}, prefill {prefill_gpus}" + return SglangConfig( + models=[ + ModelConfig( + name="default", + server_groups=[ + ServerGroupConfig(worker_type="prefill", num_gpus=prefill_gpus), + ServerGroupConfig(worker_type="decode", num_gpus=decode_gpus), + ], + ) + ] + ) + + @property + def has_pd_disaggregation(self) -> bool: + return any(m.has_pd_disaggregation for m in self.models) + + @property + def total_num_gpus(self) -> int: + return sum(m.total_num_gpus for m in self.models) + + + diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index f82e53c7af..eed889cba4 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -50,6 +50,21 @@ # ServerGroup / RolloutServer abstractions # --------------------------------------------------------------------------- +# # use_unified_pg = True for Nemo +# if use_unified_pg: +# # Create a single unified placement group for cross-node model parallelism +# all_bundles = [] +# for bundle_count in self._bundle_ct_per_node_list: +# for _ in range(bundle_count): +# all_bundles.append( +# {"CPU": num_cpus_per_bundle, 
"GPU": num_gpus_per_bundle} +# ) + +# placement_groups = [ +# placement_group( +# bundles=all_bundles, strategy=strategy, name=f"{self.name}-unified" +# ) +# ] @dataclasses.dataclass class ServerGroup: @@ -799,24 +814,29 @@ def _compute_megatron_num_gpus(args) -> int: return num -def start_rollout_servers(args, pg) -> ServerGroup: +def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: """Start rollout servers: one per model, each with its own router. Returns a dict mapping model name -> ``RolloutServer``. """ config = _resolve_sglang_config(args) - server_group: ServerGroup = None + servers: dict[str, RolloutServer] = {} gpu_offset = 0 engine_offset = 0 rollout_pg_offset = _compute_rollout_offset(args) megatron_num_gpus = _compute_megatron_num_gpus(args) - model_cfg.resolve(args) - router_ip, router_port = _start_router(args, force_new=False) - args.sglang_router_ip = router_ip - args.sglang_router_port = router_port + for model_idx, model_cfg in enumerate(config.models): + model_cfg.resolve(args) + + has_pd = model_cfg.has_pd_disaggregation + router_ip, router_port = _start_router(args, has_pd_disaggregation=has_pd, force_new=(model_idx > 0)) + + if model_idx == 0: + args.sglang_router_ip = router_ip + args.sglang_router_port = router_port server_groups: list[ServerGroup] = [] all_init_handles: list = [] @@ -874,7 +894,6 @@ def start_rollout_servers(args, pg) -> ServerGroup: return servers - def _resolve_sglang_config(args) -> SglangConfig: """Build a SglangConfig from args, choosing the right source.""" if getattr(args, "sglang_config", None) is not None: diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 6087dcbb19..dd356657c2 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -495,7 +495,7 @@ def _compute_server_args( "trust_remote_code": True, "random_seed": args.seed + rank, # memory - 
"enable_memory_saver": args.offload_rollout, + "enable_memfory_saver": args.offload_rollout, # distributed "host": host, "port": port, From a5b2563ce8b0e85dfb3c6b3f8f5a14e014b668d9 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 5 Apr 2026 14:10:43 -0700 Subject: [PATCH 04/36] update --- nemo_rl/models/generation/redesign/sglang_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index dd356657c2..6087dcbb19 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -495,7 +495,7 @@ def _compute_server_args( "trust_remote_code": True, "random_seed": args.seed + rank, # memory - "enable_memfory_saver": args.offload_rollout, + "enable_memory_saver": args.offload_rollout, # distributed "host": host, "port": port, From ce8a1c626fb965e6b7cc091795360aa0030c2091 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 5 Apr 2026 21:03:53 -0700 Subject: [PATCH 05/36] update --- nemo_rl/models/generation/redesign/config.py | 301 ++++++------------ .../generation/redesign/sglang_generation.py | 11 +- .../generation/redesign/sglang_worker.py | 106 +++--- 3 files changed, 161 insertions(+), 257 deletions(-) diff --git a/nemo_rl/models/generation/redesign/config.py b/nemo_rl/models/generation/redesign/config.py index 68168c6087..63d51b04ae 100644 --- a/nemo_rl/models/generation/redesign/config.py +++ b/nemo_rl/models/generation/redesign/config.py @@ -1,197 +1,108 @@ -"""Configuration dataclasses for SGLang engine deployment.""" - -import dataclasses -import logging - -import yaml - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class ServerGroupConfig: - """Configuration for a single server group. - - Attributes: - worker_type: One of "regular", "prefill", "decode", or "placeholder". - "placeholder" reserves GPU slots without creating engines. 
- num_gpus: Total number of GPUs for this group. - num_gpus_per_engine: GPUs per engine for this group. Overrides the - model-level or global ``--rollout-num-gpus-per-engine``. - overrides: Optional dict of SGLang ``ServerArgs`` field overrides. - These are applied on top of the base CLI ``--sglang-*`` - arguments in ``_compute_server_args``. - """ - - worker_type: str - num_gpus: int - num_gpus_per_engine: int | None = None - overrides: dict = dataclasses.field(default_factory=dict) - - def __post_init__(self): - valid_types = {"regular", "prefill", "decode", "placeholder"} - assert ( - self.worker_type in valid_types - ), f"Invalid worker_type '{self.worker_type}', must be one of {valid_types}" - assert self.num_gpus > 0, f"num_gpus must be > 0, got {self.num_gpus}" - - -@dataclasses.dataclass -class ModelConfig: - """Configuration for a single model deployment. - - Attributes: - name: Unique name for this model (e.g. "actor", "reward"). - model_path: HF checkpoint path. Falls back to ``args.hf_checkpoint``. - num_gpus_per_engine: Default GPUs per engine for all groups in this - model. Individual groups can override. - server_groups: Server group configurations for this model. - update_weights: Whether this model receives weight updates from - training. Set to ``False`` for frozen models - (reference, reward, etc.). When ``None`` (default), - automatically inferred in ``resolve()``: ``True`` if - model_path matches ``args.hf_checkpoint``, ``False`` - otherwise. 
- """ - - name: str - model_path: str | None = None - num_gpus_per_engine: int | None = None - server_groups: list[ServerGroupConfig] = dataclasses.field(default_factory=list) - update_weights: bool | None = None - - def resolve(self, args) -> None: - """Resolve per-group defaults from model-level then args-level values.""" - default_gpus_per_engine = self.num_gpus_per_engine or args.rollout_num_gpus_per_engine - default_model_path = self.model_path or args.hf_checkpoint - for g in self.server_groups: - if g.num_gpus_per_engine is None: - g.num_gpus_per_engine = default_gpus_per_engine - if "model_path" not in g.overrides: - g.overrides["model_path"] = default_model_path - - if self.server_groups: - model_paths = {g.overrides["model_path"] for g in self.server_groups} - assert len(model_paths) == 1, ( - f"Model '{self.name}' has server groups with different model_path values: " - f"{model_paths}. All server groups within a model must use the same model_path." - ) - effective_model_path = model_paths.pop() - else: - effective_model_path = default_model_path - - if self.update_weights is None: - if effective_model_path != args.hf_checkpoint: - logger.warning( - f"Model '{self.name}' uses model_path='{effective_model_path}' which differs " - f"from hf_checkpoint='{args.hf_checkpoint}'. Defaulting update_weights to False. " - f"Set update_weights explicitly in the config to suppress this warning." - ) - self.update_weights = False - else: - self.update_weights = True - - @property - def has_pd_disaggregation(self) -> bool: - return any(g.worker_type in ("prefill", "decode") for g in self.server_groups) - - @property - def total_num_gpus(self) -> int: - return sum(g.num_gpus for g in self.server_groups) - - -@dataclasses.dataclass -class SglangConfig: - """Configuration for SGLang engine deployment. - - Loaded from ``--sglang-config`` YAML file. 
- - **Config format**:: - - sglang: - - name: actor - model_path: /path/to/actor - update_weights: true # receives training weight updates (default) - num_gpus_per_engine: 2 - server_groups: - - worker_type: prefill - num_gpus: 4 - num_gpus_per_engine: 2 - - worker_type: decode - num_gpus: 8 - num_gpus_per_engine: 4 - - name: ref - model_path: /path/to/ref - update_weights: false # frozen, no weight updates - server_groups: - - worker_type: regular - num_gpus: 4 - - Each model gets its own router. ``placeholder`` groups reserve GPU - slots without creating engines. ``overrides`` are ``ServerArgs`` - field names applied on top of the base ``--sglang-*`` CLI args. - - Set ``update_weights: false`` for frozen models (reference, reward, - etc.) that should not receive weight updates from training. - - .. note:: - - ``engine_groups`` is accepted as a backward-compatible alias for - ``server_groups`` in the YAML config. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, NotRequired, TypedDict + +from nemo_rl.models.generation.interfaces import GenerationConfig + + +class SglangSpecificArgs(TypedDict): + """SGLang-specific configuration arguments. + + Most fields below map directly to SGLang's ServerArgs (see: + https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py). 
""" - models: list[ModelConfig] - - @staticmethod - def from_yaml(path: str) -> "SglangConfig": - with open(path) as f: - data = yaml.safe_load(f) - - assert "sglang" in data, ( - f"sglang config must have a 'sglang' key, got {list(data.keys())}. " - f"Wrap your server_groups inside a model entry under 'sglang'." - ) - models = [] - for m in data["sglang"]: - raw_groups = m.get("server_groups") or m.get("engine_groups") or [] - groups = [ServerGroupConfig(**g) for g in raw_groups] - models.append( - ModelConfig( - name=m["name"], - model_path=m.get("model_path"), - num_gpus_per_engine=m.get("num_gpus_per_engine"), - server_groups=groups, - update_weights=m.get("update_weights"), - ) - ) - return SglangConfig(models=models) - - @staticmethod - def from_prefill_num_servers(args) -> "SglangConfig": - """Build a config equivalent to the legacy --prefill-num-servers flag.""" - total_gpus = args.rollout_num_gpus - prefill_gpus = args.prefill_num_servers * args.rollout_num_gpus_per_engine - decode_gpus = total_gpus - prefill_gpus - assert decode_gpus > 0, f"No decode GPUs: total {total_gpus}, prefill {prefill_gpus}" - return SglangConfig( - models=[ - ModelConfig( - name="default", - server_groups=[ - ServerGroupConfig(worker_type="prefill", num_gpus=prefill_gpus), - ServerGroupConfig(worker_type="decode", num_gpus=decode_gpus), - ], - ) - ] - ) - - @property - def has_pd_disaggregation(self) -> bool: - return any(m.has_pd_disaggregation for m in self.models) - - @property - def total_num_gpus(self) -> int: - return sum(m.total_num_gpus for m in self.models) - - - + model_path: NotRequired[str] + # total number of gpus for rollout + num_gpus: NotRequired[int] + num_gpus_per_engine: NotRequired[int] + random_seed: NotRequired[int] + skip_tokenizer_init: NotRequired[bool] + disable_cuda_graph: NotRequired[bool] + disable_radix_cache: NotRequired[bool] + disable_cuda_graph_padding: NotRequired[bool] + # Enabling piecewise CUDA graph (i.e. 
setting this to False) currently crashes with + # "illegal memory access", likely due to torch 2.10 + sglang incompatibility. + # Defaulted to True (disabled) in sglang_worker.py until the upstream sglang fork is updated. + disable_piecewise_cuda_graph: NotRequired[bool] + enable_nccl_nvls: NotRequired[bool] + disable_outlines_disk_cache: NotRequired[bool] + disable_custom_all_reduce: NotRequired[bool] + disable_overlap_schedule: NotRequired[bool] + enable_mixed_chunk: NotRequired[bool] + enable_dp_attention: NotRequired[bool] + enable_deepep_moe: NotRequired[bool] + enable_ep_moe: NotRequired[bool] + enable_torch_compile: NotRequired[bool] + torch_compile_max_bs: NotRequired[int] + cuda_graph_max_bs: NotRequired[int | None] + cuda_graph_bs: NotRequired[list[int] | None] + torchao_config: NotRequired[str] + enable_nan_detection: NotRequired[bool] + enable_p2p_check: NotRequired[bool] + triton_attention_reduce_in_fp32: NotRequired[bool] + triton_attention_num_kv_splits: NotRequired[int] + num_continuous_decode_steps: NotRequired[int] + enable_memory_saver: NotRequired[bool] + allow_auto_truncate: NotRequired[bool] + attention_backend: NotRequired[str | None] + enable_multimodal: NotRequired[bool] + sampling_backend: NotRequired[str | None] + context_length: NotRequired[int | None] + mem_fraction_static: NotRequired[float | None] + max_running_requests: NotRequired[int | None] + chunked_prefill_size: NotRequired[int | None] + max_prefill_tokens: NotRequired[int] + schedule_policy: NotRequired[str] + schedule_conservativeness: NotRequired[float] + cpu_offload_gb: NotRequired[int] + dtype: NotRequired[str] + kv_cache_dtype: NotRequired[str] + dp_size: NotRequired[int] # only used for dp attention + pp_size: NotRequired[int] # pipeline parallel size + ep_size: NotRequired[int] + # lora + enable_lora: NotRequired[bool | None] + max_lora_rank: NotRequired[int | None] + lora_target_modules: NotRequired[list[str] | None] + lora_paths: NotRequired[list[str] | None] + 
max_loaded_loras: NotRequired[int] + max_loras_per_batch: NotRequired[int] + lora_backend: NotRequired[str] + # logging + log_level: NotRequired[str] + log_level_http: NotRequired[str | None] + log_requests: NotRequired[bool] + log_requests_level: NotRequired[int] + show_time_cost: NotRequired[bool] + enable_metrics: NotRequired[bool] # Exports Prometheus-like metrics + # The interval (in decoding iterations) to log throughput + # and update prometheus metrics + decode_log_interval: NotRequired[int] + # Extra loader arguments + enable_multithread_load: NotRequired[bool] + enable_fast_load: NotRequired[bool] + # Server warmup + skip_server_warmup: NotRequired[bool] + # Router Arg + sglang_router_ip: NotRequired[str] + sglang_router_port: NotRequired[int] + + +class SGLangConfig(GenerationConfig): + """Configuration for SGLang runtime.""" + + sglang_cfg: SglangSpecificArgs + sglang_kwargs: NotRequired[dict[str, Any]] diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index eed889cba4..0cc5997554 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -50,7 +50,7 @@ # ServerGroup / RolloutServer abstractions # --------------------------------------------------------------------------- -# # use_unified_pg = True for Nemo +# use_unified_pg = True for Nemo # if use_unified_pg: # # Create a single unified placement group for cross-node model parallelism # all_bundles = [] @@ -66,6 +66,11 @@ # ) # ] +# pg = cluster._init_placement_groups(strategy="PACK", use_unified_pg=True)[0] +# pg_reordered_bundle_indices = cluster._get_sorted_bundle_indices() +# + + @dataclasses.dataclass class ServerGroup: """A group of homogeneous SGLang engines with the same configuration. 
@@ -335,7 +340,6 @@ def get_updatable_engines_and_lock(self): def generate(self, rollout_id): start_time = time.time() self.rollout_id = rollout_id - self.health_monitoring_resume() data, metrics = self._get_rollout_data(rollout_id=rollout_id) self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False) _log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time) @@ -346,7 +350,6 @@ def eval(self, rollout_id): if self.args.debug_train_only: # if debug train only, we don't generate evaluation data return - self.health_monitoring_resume() if self.use_experimental_refactor: result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) @@ -814,7 +817,7 @@ def _compute_megatron_num_gpus(args) -> int: return num -def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: +def start_rollout_servers(args, pg) -> ServerGroup: """Start rollout servers: one per model, each with its own router. Returns a dict mapping model name -> ``RolloutServer``. 
diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 6087dcbb19..e7a9e1fb83 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -17,16 +17,9 @@ logger = logging.getLogger(__name__) -def get_base_gpu_id(args, rank): - num_gpus = min(args.num_gpus_per_node, args.rollout_num_gpus_per_engine) - if args.colocate: - start_index = (rank * num_gpus) % args.num_gpus_per_node - else: - num_actor_gpus = 0 if args.debug_rollout_only else args.actor_num_gpus_per_node * args.actor_num_nodes - start_index = (num_actor_gpus + rank * num_gpus) % args.num_gpus_per_node - if args.use_critic: - num_critic_gpus = args.critic_num_gpus_per_node * args.critic_num_nodes - start_index = (num_actor_gpus + num_critic_gpus + rank * num_gpus) % args.num_gpus_per_node +def get_base_gpu_id(cluster_cfg, sgl_cfg, rank): + num_gpus = min(cluster_cfg["gpus_per_node"], sgl_cfg["num_gpus_per_engine"]) + start_index = (rank * num_gpus) % cluster_cfg["gpus_per_node"] return start_index def _to_local_gpu_id(physical_gpu_id: int) -> int: @@ -104,14 +97,16 @@ def _wait_server_healthy(base_url, api_key, is_process_alive): class SGLangEngine: def __init__( self, - args, + cluster_cfg, + sgl_cfg, rank: int, worker_type: str = "regular", base_gpu_id: int | None = None, sglang_overrides: dict | None = None, num_gpus_per_engine: int | None = None, - ): - self.args = args + ): + self.cluster_cfg = cluster_cfg + self.sgl_cfg = sgl_cfg self.rank = rank self.worker_type = worker_type self.base_gpu_id = base_gpu_id @@ -127,15 +122,9 @@ def init( router_ip=None, router_port=None, ): - if env_report := self.args.env_report: - collect_and_print_node_env_report( - role="rollout", - rank=self.rank, - partial_env_report=env_report, - ) - self.router_ip = router_ip if router_ip is not None else self.args.sglang_router_ip - self.router_port = router_port if router_port is not None 
else self.args.sglang_router_port + self.router_ip = router_ip if router_ip is not None else self.sgl_cfg["sglang_router_ip"] + self.router_port = router_port if router_port is not None else self.sgl_cfg["sglang_router_port"] host = host or get_host_info()[1] @@ -153,8 +142,9 @@ def _format_v6_uri(addr): ip_part, port_part = dist_init_addr.rsplit(":", 1) dist_init_addr = f"{_format_v6_uri(ip_part)}:{port_part}" - server_args_dict, _ = _compute_server_args( - self.args, + server_args_dict = _compute_server_args( + self.cluster_cfg, + self.sgl_cfg, self.rank, dist_init_addr, nccl_port, @@ -177,7 +167,7 @@ def _init_normal(self, server_args_dict): self.process = launch_server_process(ServerArgs(**server_args_dict)) if self.node_rank == 0 and self.router_ip and self.router_port: - if parse(sglang_router.__version__) <= parse("0.2.1") or self.args.use_miles_router: + if parse(sglang_router.__version__) <= parse("0.2.1"): assert ( self.worker_type == "regular" ), "pd disaggregation is not supported in old router or miles router." 
@@ -296,7 +286,7 @@ def shutdown(self): if self.node_rank == 0: worker_url = f"http://{self.server_host}:{self.server_port}" response = None - if parse(sglang_router.__version__) <= parse("0.2.1") or self.args.use_miles_router: + if parse(sglang_router.__version__) <= parse("0.2.1"): response = requests.post( f"http://{self.router_ip}:{self.router_port}/remove_worker?url=http://{self.server_host}:{self.server_port}" ) @@ -475,7 +465,8 @@ def simulate_crash(self): def _compute_server_args( - args, + cluster_cfg, + sgl_cfg, rank, dist_init_addr, nccl_port, @@ -485,17 +476,17 @@ def _compute_server_args( sglang_overrides: dict | None = None, num_gpus_per_engine: int | None = None, ): - _gpus_per_engine = num_gpus_per_engine or args.rollout_num_gpus_per_engine - nnodes = max(1, _gpus_per_engine // args.num_gpus_per_node) + _gpus_per_engine = num_gpus_per_engine or sgl_cfg["num_gpus_per_engine"] + nnodes = max(1, _gpus_per_engine // cluster_cfg["gpus_per_node"]) node_rank = rank % nnodes - base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(args, rank) + base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(cluster_cfg, sgl_cfg, rank) base = _to_local_gpu_id(base) kwargs = { - "model_path": args.hf_checkpoint, + "model_path": sgl_cfg["model_path"], "trust_remote_code": True, - "random_seed": args.seed + rank, + "random_seed": sgl_cfg["random_seed"] + rank, # memory - "enable_memory_saver": args.offload_rollout, + "enable_memory_saver": sgl_cfg["enable_memory_saver"], # distributed "host": host, "port": port, @@ -507,29 +498,39 @@ def _compute_server_args( "base_gpu_id": base, # parallel "tp_size": _gpus_per_engine, - "dp_size": args.sglang_dp_size, - "pp_size": args.sglang_pp_size, - "ep_size": args.sglang_ep_size, + "dp_size": sgl_cfg["dp_size"], + "pp_size": sgl_cfg["pp_size"], + "ep_size": sgl_cfg["ep_size"], # always skip warmup to prevent warmup timeout. 
"skip_server_warmup": True, # always enable draft weights cpu backup so that we run training without mtp weights. "enable_draft_weights_cpu_backup": True, } - if sglang_overrides: - kwargs.update(sglang_overrides) - - if args.use_rollout_routing_replay: - kwargs["enable_return_routed_experts"] = True - if args.fp16: - kwargs["dtype"] = "float16" - external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] - + for key in [ + "dtype", + "kv_cache_dtype", + "context_length", + "max_running_requests", + "chunked_prefill_size", + "max_prefill_tokens", + "schedule_policy", + "schedule_conservativeness", + "cpu_offload_gb", + "log_level", + "mem_fraction_static", + "allow_auto_truncate", + "disable_piecewise_cuda_graph", + ]: + if key in sgl_cfg: + kwargs[key] = sgl_cfg[key] unused_keys = set(kwargs.keys()) + for attr in dataclasses.fields(ServerArgs): - if hasattr(args, f"sglang_{attr.name}") and attr.name not in kwargs: - kwargs[attr.name] = getattr(args, f"sglang_{attr.name}") + sgl_key = f"sglang_{attr.name}" + if sgl_key in sgl_cfg and attr.name not in kwargs: + kwargs[attr.name] = sgl_cfg[sgl_key] unused_keys.discard(attr.name) # for compatibility with old args @@ -538,16 +539,5 @@ def _compute_server_args( for key in unused_keys: kwargs.pop(key) - return kwargs, external_engine_need_check_fields - + return kwargs -_EXTERNAL_ENGINE_SKIP_CHECK_FIELDS = [ - "model_path", - "trust_remote_code", - "random_seed", - "nccl_port", - "dist_init_addr", - "skip_server_warmup", - "enable_draft_weights_cpu_backup", - "mem_fraction_static", -] From 2501526f866889ee3d4bc14bf57f84762591a44b Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Wed, 8 Apr 2026 17:20:14 -0700 Subject: [PATCH 06/36] update --- nemo_rl/distributed/virtual_cluster.py | 48 +++++ nemo_rl/models/generation/redesign/config.py | 18 +- .../generation/redesign/sglang_generation.py | 159 ++++---------- .../generation/redesign/sglang_worker.py | 62 ++---- 
nemo_rl/models/generation/redesign/utils.py | 198 +++++++++++++++++- 5 files changed, 313 insertions(+), 172 deletions(-) diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index 96282ad623..b609342e5b 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -185,6 +185,54 @@ def get_gpu_id(self): return ray.get_gpu_ids()[0] +def get_reordered_bundle_and_gpu_ids( + pg: PlacementGroup, +) -> tuple[list[int], list[int]]: + """Return bundle indices and GPU IDs sorted by (node_id, gpu_id). + + Uses ``GetGPUIDActor`` to discover the physical GPU ID assigned to each + bundle, then sorts identically to the pattern in ``pg.py``. + + Returns: + (reordered_bundle_indices, reordered_gpu_ids) + """ + pg_data = placement_group_table(pg) + num_bundles = len(pg_data["bundles"]) + bundle_to_node_ids = pg_data["bundles_to_node_id"] + + info_actors = [] + for i in range(num_bundles): + info_actors.append( + GetGPUIDActor.options( + num_cpus=0.01, + num_gpus=0.01, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=i, + ), + ).remote() + ) + + gpu_ids = ray.get([actor.get_gpu_id.remote() for actor in info_actors]) + for actor in info_actors: + ray.kill(actor) + + bundle_infos = [(i, bundle_to_node_ids[i], gpu_ids[i]) for i in range(num_bundles)] + sorted_infos = sorted(bundle_infos, key=lambda x: (x[1], x[2])) + + reordered_bundle_indices = [info[0] for info in sorted_infos] + reordered_gpu_ids = [gpu_ids[info[0]] for info in sorted_infos] + + for i, info in enumerate(sorted_infos): + actual_idx = info[0] + logger.info( + f" bundle {i:4}, actual_bundle_index: {actual_idx:4}, " + f"node: {info[1]}, gpu: {gpu_ids[actual_idx]}" + ) + + return reordered_bundle_indices, reordered_gpu_ids + + class ResourceInsufficientError(Exception): """Exception raised when the cluster does not have enough resources to satisfy the requested configuration.""" diff 
--git a/nemo_rl/models/generation/redesign/config.py b/nemo_rl/models/generation/redesign/config.py index 63d51b04ae..8e24bb272a 100644 --- a/nemo_rl/models/generation/redesign/config.py +++ b/nemo_rl/models/generation/redesign/config.py @@ -16,7 +16,6 @@ from nemo_rl.models.generation.interfaces import GenerationConfig - class SglangSpecificArgs(TypedDict): """SGLang-specific configuration arguments. @@ -25,9 +24,7 @@ class SglangSpecificArgs(TypedDict): """ model_path: NotRequired[str] - # total number of gpus for rollout - num_gpus: NotRequired[int] - num_gpus_per_engine: NotRequired[int] + # Total number of gpus for rollout random_seed: NotRequired[int] skip_tokenizer_init: NotRequired[bool] disable_cuda_graph: NotRequired[bool] @@ -96,13 +93,20 @@ class SglangSpecificArgs(TypedDict): enable_fast_load: NotRequired[bool] # Server warmup skip_server_warmup: NotRequired[bool] - # Router Arg + +class SGLangServer(TypedDict): + sglang_server_concurrency: int + num_gpus: NotRequired[int] + num_gpus_per_engine: NotRequired[int] + +class SGLangRouter(TypedDict): sglang_router_ip: NotRequired[str] sglang_router_port: NotRequired[int] - + router_policy: NotRequired[str] class SGLangConfig(GenerationConfig): """Configuration for SGLang runtime.""" - sglang_cfg: SglangSpecificArgs + sglang_server: SGLangServer + sglang_router: SGLangRouter sglang_kwargs: NotRequired[dict[str, Any]] diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 0cc5997554..b2d6f53f33 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -68,7 +68,6 @@ # pg = cluster._init_placement_groups(strategy="PACK", use_unified_pg=True)[0] # pg_reordered_bundle_indices = cluster._get_sorted_bundle_indices() -# @dataclasses.dataclass @@ -80,15 +79,13 @@ class ServerGroup: in PD disaggregation). 
""" - args: Any pg: Any # (placement_group, reordered_bundle_indices, reordered_gpu_ids) all_engines: list num_gpus_per_engine: int + num_gpus_per_node: int num_new_engines: int - worker_type: str = "regular" # "regular", "prefill", or "decode" rank_offset: int = 0 gpu_offset: int = 0 - sglang_overrides: dict = dataclasses.field(default_factory=dict) needs_offload: bool = False model_path: str | None = None router_ip: str | None = None @@ -96,7 +93,7 @@ class ServerGroup: @property def nodes_per_engine(self): - return max(1, self.num_gpus_per_engine // self.args.num_gpus_per_node) + return max(1, self.num_gpus_per_engine // self.num_gpus_per_node) @property def engines(self): @@ -127,14 +124,9 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis """ if port_cursors is None: port_cursors = {} - if self.args.debug_train_only or self.worker_type == "placeholder": - self.num_new_engines = 0 - return [], port_cursors - - num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node) + num_gpu_per_engine = min(self.num_gpus_per_engine, self.num_gpus_per_node) pg, reordered_bundle_indices, reordered_gpu_ids = self.pg - RolloutRayActor = ray.remote(SGLangEngine) rollout_engines = [] @@ -178,9 +170,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis ).remote( self.args, rank=global_rank, - worker_type=self.worker_type, base_gpu_id=base_gpu_id, - sglang_overrides=self.sglang_overrides, num_gpus_per_engine=self.num_gpus_per_engine, ) @@ -194,9 +184,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis base_port = max(port_cursors.values()) if port_cursors else 15000 addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( - args=self.args, rollout_engines=rollout_engines, - worker_type=self.worker_type, num_gpus_per_engine=self.num_gpus_per_engine, rank_offset=self.rank_offset, base_port=base_port, @@ -228,7 +216,6 @@ def recover(self): 
updatable_new_engines = [] non_updatable_groups_engines: list[tuple[str, list]] = [] for g, dead_indices in zip(self.server_groups, dead_per_group, strict=True): - logger.info(f"Recovered {g.num_new_engines} dead rollout engines (worker_type={g.worker_type})") assert g.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" if g.needs_offload and dead_indices: new_engines = [g.all_engines[i] for i in dead_indices] @@ -280,21 +267,25 @@ def onload_weights_from_disk(self): ] # --------------------------------------------------------------------------- -# RolloutManager +# SGLangGeneration # --------------------------------------------------------------------------- -@ray.remote -class RolloutManager: +class SGLangGeneration(GenerationInterface): """The class to run rollout and convert rollout data to training data.""" - def __init__(self, args, pg): + def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglang_cfg: SGLangConfig): configure_logger() - self.pg = pg - self.args = args - # TODO make args immutable - init_tracking(args, primary=False, router_addr=f"http://{args.sglang_router_ip}:{args.sglang_router_port}") + self.cluster = cluster + self.cluster_cfg = cluster_cfg + self.sglang_cfg = sglang_cfg + self.pg = cluster._init_placement_groups( + strategy="PACK", + use_unified_pg=True, + ) + self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids = get_reordered_bundle_and_gpu_ids(self.pg) + # TODO: change to each implementation self.generate_rollout = load_function(self.args.rollout_function_path) self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = None @@ -308,11 +299,8 @@ def __init__(self, args, pg): logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.") logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") - if self.args.debug_train_only: - self.server_group = None - else: 
- init_http_client(args) - self.server_group = start_rollout_servers(args, pg) + init_http_client(args) + self.server_group = start_rollout_servers(args, (self.pg, self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids)) self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() self.rollout_id = -1 @@ -347,10 +335,6 @@ def generate(self, rollout_id): return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"]) def eval(self, rollout_id): - if self.args.debug_train_only: - # if debug train only, we don't generate evaluation data - return - if self.use_experimental_refactor: result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) else: @@ -371,11 +355,13 @@ def offload(self, tags: list[str] | None = None): if engine is not None ] return ray.get(handles) if handles else [] - self.server_group.offload() - + else: + handles = self.server_group.offload() + return ray.get(handles) if handles else [] def onload(self, tags: list[str] | None = None): - self.server_group.onload(tags) + handles = self.server_group.onload(tags) + return ray.get(handles) if handles else [] def onload_weights(self): self.server_group.onload_weights() @@ -663,10 +649,9 @@ def _split_train_data_by_dp(self, data, dp_size): def _allocate_rollout_engine_addr_and_ports_normal( *, - args, + cluster_cfg, + sglang_cfg, rollout_engines, - worker_type="regular", - num_gpus_per_engine=None, rank_offset=0, base_port=15000, ): @@ -676,8 +661,13 @@ def _allocate_rollout_engine_addr_and_ports_normal( # 2. nccl port # 3. dist_init_addr port # 4. 
other ports for dp_attention, which is of size 4 + dp_size - _gpus_per_engine = num_gpus_per_engine or args.rollout_num_gpus_per_engine - num_engines_per_node = max(1, args.num_gpus_per_node // _gpus_per_engine) + + sglang_dp_size = sglang_cfg["sglang_cfg"]["dp_size"] + num_gpus_per_engine = sglang_cfg["sglang_server"]["num_gpus_per_engine"] + num_gpus_per_node = cluster_cfg["gpus_per_node"] + + _gpus_per_engine = num_gpus_per_engine + num_engines_per_node = max(1, num_gpus_per_node // _gpus_per_engine) addr_and_ports: dict[int, dict] = {} # Track per-node port cursors so that different server groups (called @@ -727,19 +717,16 @@ def addr(): addr_and_ports[current_rank]["port"] = get_port() addr_and_ports[current_rank]["nccl_port"] = get_port() - if worker_type == "prefill": - addr_and_ports[current_rank]["disaggregation_bootstrap_port"] = get_port() - - if _gpus_per_engine > args.num_gpus_per_node: - num_node_per_engine = _gpus_per_engine // args.num_gpus_per_node + if _gpus_per_engine > num_gpus_per_node: + num_node_per_engine = _gpus_per_engine // num_gpus_per_node if local_rank % num_node_per_engine == 0: - dist_init_addr = f"{get_addr()}:{get_port(30 + args.sglang_dp_size)}" + dist_init_addr = f"{get_addr()}:{get_port(30 + sglang_dp_size)}" for i in range(num_node_per_engine): addr_and_ports.setdefault(rank + i, {}) addr_and_ports[rank + i]["dist_init_addr"] = dist_init_addr else: for i in range(num_engines_on_this_node): - addr_and_ports[rank + i]["dist_init_addr"] = f"{get_addr()}:{get_port(30 + args.sglang_dp_size)}" + addr_and_ports[rank + i]["dist_init_addr"] = f"{get_addr()}:{get_port(30 + sglang_dp_size)}" for i, _ in rollout_engines: for key in ["port", "nccl_port", "dist_init_addr"]: @@ -752,7 +739,7 @@ def addr(): # Router + server bootstrap # --------------------------------------------------------------------------- -def _start_router(args, *, force_new: bool = False) -> tuple[str, int]: +def _start_router(args: SGLangConfig, *, force_new: bool = 
False) -> tuple[str, int]: """Start sgl router or miles router and return (router_ip, router_port). If ``args.sglang_router_ip`` is already set and ``force_new`` is False, @@ -771,11 +758,12 @@ def _start_router(args, *, force_new: bool = False) -> tuple[str, int]: from sglang_router.launch_router import RouterArgs - from miles.utils.http_utils import run_router - - router_args = RouterArgs.from_cli_args(args, use_router_prefix=True) + # pass from + router_args = RouterArgs() router_args.host = router_ip router_args.port = router_port + if args["sglang_router"]["router_policy"] is not None: + router_args.router_policy = args["sglang_router"]["router_policy"] router_args.prometheus_port = find_available_port(random.randint(4000, 5000)) router_args.log_level = "warn" router_args.request_timeout_secs = args.sglang_router_request_timeout_secs @@ -793,30 +781,6 @@ def _start_router(args, *, force_new: bool = False) -> tuple[str, int]: logger.info(f"Router launched at {router_ip}:{router_port}") return router_ip, router_port -def _compute_rollout_offset(args) -> int: - """Offset (in PG bundle slots) where rollout GPUs start.""" - if args.debug_train_only or args.debug_rollout_only or args.colocate: - return 0 - if getattr(args, "critic_train_only", False): - return args.critic_num_nodes * args.critic_num_gpus_per_node - offset = args.actor_num_nodes * args.actor_num_gpus_per_node - if getattr(args, "use_critic", False): - offset += args.critic_num_nodes * args.critic_num_gpus_per_node - return offset - - -def _compute_megatron_num_gpus(args) -> int: - """Total number of megatron (actor + critic) GPU slots in the placement group.""" - if getattr(args, "debug_rollout_only", False): - return 0 - if getattr(args, "critic_train_only", False): - return args.critic_num_nodes * args.critic_num_gpus_per_node - num = args.actor_num_nodes * args.actor_num_gpus_per_node - if getattr(args, "use_critic", False): - num += args.critic_num_nodes * args.critic_num_gpus_per_node - return 
num - - def start_rollout_servers(args, pg) -> ServerGroup: """Start rollout servers: one per model, each with its own router. @@ -824,12 +788,9 @@ def start_rollout_servers(args, pg) -> ServerGroup: """ config = _resolve_sglang_config(args) - servers: dict[str, RolloutServer] = {} gpu_offset = 0 engine_offset = 0 - - rollout_pg_offset = _compute_rollout_offset(args) - megatron_num_gpus = _compute_megatron_num_gpus(args) + rollout_pg_offset = 0 for model_idx, model_cfg in enumerate(config.models): model_cfg.resolve(args) @@ -850,28 +811,18 @@ def start_rollout_servers(args, pg) -> ServerGroup: num_gpu_per_engine_local = min(gpus_per_engine, args.num_gpus_per_node) num_engines = group_cfg.num_gpus // num_gpu_per_engine_local - group_abs_start = rollout_pg_offset + gpu_offset - needs_offload = args.offload_rollout and group_abs_start < megatron_num_gpus - overrides = dict(group_cfg.overrides) - if args.offload_rollout and not needs_offload: - overrides.setdefault("enable_memory_saver", False) - logger.info( - f"Engine group '{group_cfg.worker_type}' gpu_offset={gpu_offset} " - f"(abs={group_abs_start}): needs_offload={needs_offload}" - ) + needs_offload = args.offload_rollout group = ServerGroup( args=args, pg=pg, - all_engines=[None] * num_engines if group_cfg.worker_type != "placeholder" else [], + all_engines=[None] * num_engines, num_gpus_per_engine=gpus_per_engine, num_new_engines=0, - worker_type=group_cfg.worker_type, rank_offset=engine_offset, gpu_offset=gpu_offset, - sglang_overrides=overrides, needs_offload=needs_offload, - model_path=overrides.get("model_path", args.hf_checkpoint), + model_path= sglang_cfg["sglang_cfg"]["model_path"], router_ip=router_ip, router_port=router_port, ) @@ -897,28 +848,6 @@ def start_rollout_servers(args, pg) -> ServerGroup: return servers -def _resolve_sglang_config(args) -> SglangConfig: - """Build a SglangConfig from args, choosing the right source.""" - if getattr(args, "sglang_config", None) is not None: - config = 
SglangConfig.from_yaml(args.sglang_config) - expected = args.rollout_num_gpus - actual = config.total_num_gpus - assert actual == expected, f"sglang_config total GPUs ({actual}) != rollout_num_gpus ({expected})" - return config - - if args.prefill_num_servers is not None: - return SglangConfig.from_prefill_num_servers(args) - - return SglangConfig( - models=[ - ModelConfig( - name="default", - server_groups=[ServerGroupConfig(worker_type="regular", num_gpus=args.rollout_num_gpus)], - ) - ] - ) - - # --------------------------------------------------------------------------- # Logging / metrics helpers (unchanged) # --------------------------------------------------------------------------- diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index e7a9e1fb83..7ccd708896 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -17,8 +17,8 @@ logger = logging.getLogger(__name__) -def get_base_gpu_id(cluster_cfg, sgl_cfg, rank): - num_gpus = min(cluster_cfg["gpus_per_node"], sgl_cfg["num_gpus_per_engine"]) +def get_base_gpu_id(cluster_cfg, sglang_cfg, rank): + num_gpus = min(cluster_cfg["gpus_per_node"], sglang_cfg["sglang_server"]["num_gpus_per_engine"]) start_index = (rank * num_gpus) % cluster_cfg["gpus_per_node"] return start_index @@ -98,19 +98,15 @@ class SGLangEngine: def __init__( self, cluster_cfg, - sgl_cfg, + sglang_cfg, rank: int, - worker_type: str = "regular", base_gpu_id: int | None = None, - sglang_overrides: dict | None = None, num_gpus_per_engine: int | None = None, ): self.cluster_cfg = cluster_cfg - self.sgl_cfg = sgl_cfg + self.sglang_cfg = sglang_cfg self.rank = rank - self.worker_type = worker_type self.base_gpu_id = base_gpu_id - self.sglang_overrides = sglang_overrides or {} self.num_gpus_per_engine = num_gpus_per_engine def init( @@ -123,8 +119,8 @@ def init( router_port=None, ): - self.router_ip = router_ip 
if router_ip is not None else self.sgl_cfg["sglang_router_ip"] - self.router_port = router_port if router_port is not None else self.sgl_cfg["sglang_router_port"] + self.router_ip = router_ip if router_ip is not None else self.sglang_cfg["sglang_cfg"]["sglang_router_ip"] + self.router_port = router_port if router_port is not None else self.sglang_cfg["sglang_cfg"]["sglang_router_port"] host = host or get_host_info()[1] @@ -144,14 +140,13 @@ def _format_v6_uri(addr): server_args_dict = _compute_server_args( self.cluster_cfg, - self.sgl_cfg, + self.sglang_cfg, self.rank, dist_init_addr, nccl_port, host, port, base_gpu_id=self.base_gpu_id, - sglang_overrides=self.sglang_overrides, num_gpus_per_engine=self.num_gpus_per_engine, ) @@ -168,16 +163,13 @@ def _init_normal(self, server_args_dict): if self.node_rank == 0 and self.router_ip and self.router_port: if parse(sglang_router.__version__) <= parse("0.2.1"): - assert ( - self.worker_type == "regular" - ), "pd disaggregation is not supported in old router or miles router." 
response = requests.post( f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}" ) else: payload = { "url": f"http://{self.server_host}:{self.server_port}", - "worker_type": self.worker_type, + "worker_type": "regular", } response = requests.post( f"http://{self.router_ip}:{self.router_port}/workers", @@ -463,30 +455,28 @@ def simulate_crash(self): logger.info(f"Simulating crash on engine {self.server_host}:{self.server_port}...") self.shutdown() - def _compute_server_args( cluster_cfg, - sgl_cfg, + sglang_cfg, rank, dist_init_addr, nccl_port, host, port, base_gpu_id: int | None = None, - sglang_overrides: dict | None = None, num_gpus_per_engine: int | None = None, ): - _gpus_per_engine = num_gpus_per_engine or sgl_cfg["num_gpus_per_engine"] + _gpus_per_engine = num_gpus_per_engine or sglang_cfg["sglang_server"]["num_gpus_per_engine"] nnodes = max(1, _gpus_per_engine // cluster_cfg["gpus_per_node"]) node_rank = rank % nnodes - base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(cluster_cfg, sgl_cfg, rank) + base = base_gpu_id if base_gpu_id is not None else get_base_gpu_id(cluster_cfg, sglang_cfg, rank) base = _to_local_gpu_id(base) kwargs = { - "model_path": sgl_cfg["model_path"], + "model_path": sglang_cfg["sglang_cfg"]["model_path"], "trust_remote_code": True, - "random_seed": sgl_cfg["random_seed"] + rank, + "random_seed": sglang_cfg["sglang_cfg"]["random_seed"] + rank, # memory - "enable_memory_saver": sgl_cfg["enable_memory_saver"], + "enable_memory_saver": sglang_cfg["sglang_cfg"]["enable_memory_saver"], # distributed "host": host, "port": port, @@ -498,9 +488,9 @@ def _compute_server_args( "base_gpu_id": base, # parallel "tp_size": _gpus_per_engine, - "dp_size": sgl_cfg["dp_size"], - "pp_size": sgl_cfg["pp_size"], - "ep_size": sgl_cfg["ep_size"], + "dp_size": sglang_cfg["sglang_cfg"]["dp_size"], + "pp_size": sglang_cfg["sglang_cfg"]["pp_size"], + "ep_size": 
sglang_cfg["sglang_cfg"]["ep_size"], # always skip warmup to prevent warmup timeout. "skip_server_warmup": True, # always enable draft weights cpu backup so that we run training without mtp weights. @@ -522,22 +512,8 @@ def _compute_server_args( "allow_auto_truncate", "disable_piecewise_cuda_graph", ]: - if key in sgl_cfg: - kwargs[key] = sgl_cfg[key] - - unused_keys = set(kwargs.keys()) - - for attr in dataclasses.fields(ServerArgs): - sgl_key = f"sglang_{attr.name}" - if sgl_key in sgl_cfg and attr.name not in kwargs: - kwargs[attr.name] = sgl_cfg[sgl_key] - unused_keys.discard(attr.name) - - # for compatibility with old args - if len(unused_keys) > 0: - logger.info(f"Warning: The following arguments is not supported in the current sglang: {unused_keys}.") - for key in unused_keys: - kwargs.pop(key) + if key in sglang_cfg["sglang_cfg"]: + kwargs[key] = sglang_cfg["sglang_cfg"][key] return kwargs diff --git a/nemo_rl/models/generation/redesign/utils.py b/nemo_rl/models/generation/redesign/utils.py index a7df1ecc78..64457f64fd 100644 --- a/nemo_rl/models/generation/redesign/utils.py +++ b/nemo_rl/models/generation/redesign/utils.py @@ -1,10 +1,4 @@ -@ray.remote(num_gpus=1) -class InfoActor: - def get_ip_and_gpu_id(self): - return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0] - - def find_available_port(base_port: int): port = base_port + random.randint(100, 1000) while True: @@ -104,4 +98,194 @@ def get_free_port(start_port=10000, consecutive=1): port = start_port while not all(is_port_available(port + i) for i in range(consecutive)): port += 1 - return port \ No newline at end of file + return port + +def _wrap_ipv6(host): + """Wrap IPv6 address in [] if needed.""" + try: + ipaddress.IPv6Address(host.strip("[]")) + return f"[{host.strip('[]')}]" + except ipaddress.AddressValueError: + return host + + +def run_router(args): + try: + from sglang_router.launch_router import launch_router + + router = launch_router(args) + if router is None: + return 1 + return 
0 + except Exception as e: + logger.info(e) + return 1 + + +def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None: + """Terminate a process gracefully, with forced kill as fallback. + + Args: + process: The process to terminate + timeout: Seconds to wait for graceful termination before forcing kill + """ + if not process.is_alive(): + return + + process.terminate() + process.join(timeout=timeout) + if process.is_alive(): + process.kill() + process.join() + + +_http_client: httpx.AsyncClient | None = None +_client_concurrency: int = 0 + +# Optional Ray-based distributed POST dispatch +_distributed_post_enabled: bool = False +_post_actors: list[object] = [] +_post_actor_idx: int = 0 + + +def _next_actor(): + global _post_actor_idx + if not _post_actors: + return None + actor = _post_actors[_post_actor_idx % len(_post_actors)] + _post_actor_idx = (_post_actor_idx + 1) % len(_post_actors) + return actor + + +async def _post(client, url, payload, max_retries=60, action="post"): + retry_count = 0 + while retry_count < max_retries: + try: + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) + response.raise_for_status() + try: + output = response.json() + except json.JSONDecodeError: + output = response.text + except Exception as e: + retry_count += 1 + + if isinstance(e, httpx.HTTPStatusError): + response_text = e.response.text + else: + response_text = None + + logger.info( + f"Error: {e}, retrying... (attempt {retry_count}/{max_retries}, url={url}, response={response_text})" + ) + if retry_count >= max_retries: + logger.info(f"Max retries ({max_retries}) reached, failing... 
(url={url})") + raise e + await asyncio.sleep(1) + continue + break + + return output + + +def init_http_client(args: SglangSpecificArgs): + """Initialize HTTP client and optionally enable distributed POST via Ray.""" + global _http_client, _client_concurrency, _distributed_post_enabled + if not args.rollout_num_gpus: + return + + _client_concurrency = args["sglang_server"]["sglang_server_concurrency"] * args["sglang_server"]["num_gpus"] // args["sglang_server"]["num_gpus_per_engine"] + if _http_client is None: + _http_client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=_client_concurrency), + timeout=httpx.Timeout(None), + ) + + # Optionally initialize distributed POST via Ray without changing interfaces + if args.use_distributed_post: + _init_ray_distributed_post(args) + _distributed_post_enabled = True + + +def _init_ray_distributed_post(args): + """Initialize one or more Ray async actors per node for HTTP POST. + + Uses NodeAffinitySchedulingStrategy to place actors on distinct nodes. + Controlled by MILES_HTTP_POST_ACTORS_PER_NODE. 
# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution)
async def post(url, payload, max_retries=60, action="post"):
    """Send an HTTP request (POST by default; also GET/DELETE) with retries.

    When distributed dispatch is enabled, the request is round-robined over
    the per-node Ray poster actors; any failure there falls back to the
    process-local client.
    """
    if _distributed_post_enabled and _post_actors:
        try:
            import ray

            actor = _next_actor()
            if actor is not None:
                # Block on ray.get in a worker thread so the event loop stays free.
                ref = actor.do_post.remote(url, payload, max_retries, action=action)
                return await asyncio.to_thread(ray.get, ref)
        except Exception as e:
            logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})")
            # fall through to the local client below

    return await _post(_http_client, url, payload, max_retries, action=action)


# TODO unify w/ `post` to add retries and remote-execution
async def get(url):
    """GET *url* via the shared local client and return the parsed JSON body."""
    response = await _http_client.get(url)
    response.raise_for_status()
    return response.json()
__init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan ) self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids = get_reordered_bundle_and_gpu_ids(self.pg) - # TODO: change to each implementation - self.generate_rollout = load_function(self.args.rollout_function_path) - self.eval_generate_rollout = load_function(self.args.eval_function_path) - self.custom_reward_post_process_func = None - if self.args.custom_reward_post_process_path is not None: - self.custom_reward_post_process_func = load_function(self.args.custom_reward_post_process_path) - self.custom_convert_samples_to_train_data_func = None - if self.args.custom_convert_samples_to_train_data_path is not None: - self.custom_convert_samples_to_train_data_func = load_function( - self.args.custom_convert_samples_to_train_data_path - ) - logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.") - logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") - init_http_client(args) self.server_group = start_rollout_servers(args, (self.pg, self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids)) self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() - self.rollout_id = -1 - self._metric_checker = MetricChecker.maybe_create(args) def dispose(self): if self._metric_checker is not None: @@ -325,28 +309,6 @@ def get_updatable_engines_and_lock(self): num_new = server_group.num_new_engines if server_group else 0 return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets - def generate(self, rollout_id): - start_time = time.time() - self.rollout_id = rollout_id - data, metrics = self._get_rollout_data(rollout_id=rollout_id) - self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False) - _log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time) - data = self._convert_samples_to_train_data(data) - return self._split_train_data_by_dp(data, 
self.train_parallel_config["dp_size"]) - - def eval(self, rollout_id): - if self.use_experimental_refactor: - result = call_rollout_function(self.eval_generate_rollout, RolloutFnEvalInput(rollout_id=rollout_id)) - else: - result = call_rollout_fn( - self.eval_generate_rollout, self.args, rollout_id, self.data_source, evaluation=True - ) - data = result.data - self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) - metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) - if self._metric_checker is not None: - self._metric_checker.on_eval(metrics) - def offload(self, tags: list[str] | None = None): if tags is not None: handles = [ @@ -399,249 +361,8 @@ def clear_updatable_num_new_engines(self): def check_weights(self, action: str): return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) - def _get_rollout_data(self, rollout_id): - if self.args.load_debug_rollout_data: - data = torch.load( - self.args.load_debug_rollout_data.format(rollout_id=rollout_id), - weights_only=False, - )["samples"] - data = [Sample.from_dict(sample) for sample in data] - if (ratio := self.args.load_debug_rollout_data_subsample) is not None: - original_num_rows = len(data) - rough_subsample_num_rows = int(original_num_rows * ratio) - data = data[: rough_subsample_num_rows // 2] + data[-rough_subsample_num_rows // 2 :] - logger.info( - f"Subsample loaded debug rollout data using {ratio=} and change num rows {original_num_rows} -> {len(data)}" - ) - metrics = None - else: - if self.use_experimental_refactor: - data = call_rollout_function(self.generate_rollout, RolloutFnTrainInput(rollout_id=rollout_id)) - else: - data = call_rollout_fn( - self.generate_rollout, self.args, rollout_id, self.data_source, evaluation=False - ) - metrics = data.metrics - data = data.samples - # flatten the data if it is a list of lists - while isinstance(data[0], list): - data = list(itertools.chain.from_iterable(data)) - - if not 
self.args.disable_rollout_trim_samples: - global_batch_size = self.args.global_batch_size - if self.args.use_dynamic_global_batch_size: - logger.info(f"Collected {len(data)} samples from rollout to train with dynamic global batch size") - # TODO: this is a temporary solution, we should directly save dynamic_global_batch_size to rollout data - self._dynamic_global_batch_size = self._compute_dynamic_global_batch_size(len(data)) - global_batch_size = self._dynamic_global_batch_size - - if len(data) % global_batch_size != 0: - trim_len = (len(data) // global_batch_size) * global_batch_size - if trim_len == 0: - raise ValueError(f"Not enough samples {len(data)} for global_batch_size {global_batch_size}") - origin_data_length = len(data) - data = data[:trim_len] - logger.info(f"trim number of samples from {origin_data_length} to {trim_len}") - logger.info(f"Final collected {len(data)} samples from rollout to train") - - return data, metrics - - def _compute_dynamic_global_batch_size(self, num_samples: int) -> int: - """Calculate dynamic global_batch_size to ensure only one training step. 
- - Strategy: global_batch_size = num_samples rounded down to a multiple of dp_size - This ensures num_steps_per_rollout = num_samples // global_batch_size = 1 - """ - dp_size = self.train_parallel_config["dp_size"] - original_gbs = self.args.global_batch_size - - # Round down to a multiple of dp_size to ensure only one training step - dynamic_gbs = (num_samples // dp_size) * dp_size - - if dynamic_gbs == 0: - # Too few samples, use at least dp_size - dynamic_gbs = dp_size - logger.warning(f"num_samples={num_samples} < dp_size={dp_size}, using dp_size as global_batch_size") - - # Calculate how many samples will be discarded - wasted = num_samples - dynamic_gbs - - if dynamic_gbs != original_gbs or wasted > 0: - logger.info( - f"Dynamic global_batch_size: {original_gbs} -> {dynamic_gbs} " - f"(num_samples={num_samples}, dp_size={dp_size}, " - f"num_steps=1, wasted={wasted})" - ) - - return dynamic_gbs - - def _save_debug_rollout_data(self, data, rollout_id, evaluation: bool): - # TODO to be refactored (originally Buffer._set_data) - if (path_template := self.args.save_debug_rollout_data) is not None: - path = Path(path_template.format(rollout_id=("eval_" if evaluation else "") + str(rollout_id))) - logger.info(f"Save debug rollout data to {path}") - path.parent.mkdir(parents=True, exist_ok=True) - - # TODO may improve the format - if evaluation: - dump_data = dict( - samples=[sample.to_dict() for dataset_name, info in data.items() for sample in info["samples"]] - ) - else: - dump_data = dict( - samples=[sample.to_dict() for sample in data], - ) - - torch.save(dict(rollout_id=rollout_id, **dump_data), path) - - def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]): - if self.custom_reward_post_process_func is not None: - return self.custom_reward_post_process_func(self.args, samples) - - raw_rewards = [sample.get_reward_value(self.args) for sample in samples] - if ( - self.args.advantage_estimator in ["grpo", "gspo", 
"reinforce_plus_plus_baseline"] - and self.args.rewards_normalization - ): - # group norm - rewards = torch.tensor(raw_rewards, dtype=torch.float) - if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size: - rewards = rewards.reshape(-1, self.args.n_samples_per_prompt) - else: - # when samples count are not equal in each group - rewards = rewards.view(-1, rewards.shape[-1]) - mean = rewards.mean(dim=-1, keepdim=True) - rewards = rewards - mean - - if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization: - std = rewards.std(dim=-1, keepdim=True) - rewards = rewards / (std + 1e-6) - - return raw_rewards, rewards.flatten().tolist() - - return raw_rewards, raw_rewards - - def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]): - """ - Convert inference generated samples to training data. - """ - if self.custom_convert_samples_to_train_data_func is not None: - return self.custom_convert_samples_to_train_data_func(self.args, samples) - - raw_rewards, rewards = self._post_process_rewards(samples) - - assert len(raw_rewards) == len(samples) - assert len(rewards) == len(samples) - - train_data = { - "tokens": [sample.tokens for sample in samples], - "response_lengths": [sample.response_length for sample in samples], - # some reward model, e.g. remote rm, may return multiple rewards, - # we could use key to select the reward. 
- "rewards": rewards, - "raw_reward": raw_rewards, - "truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples], - "sample_indices": [sample.index for sample in samples], - } - - # loss mask - # TODO: compress the loss mask - loss_masks = [] - for sample in samples: - # always instantiate loss_mask if not provided - if sample.loss_mask is None: - sample.loss_mask = [1] * sample.response_length - - assert ( - len(sample.loss_mask) == sample.response_length - ), f"loss mask length {len(sample.loss_mask)} != response length {sample.response_length}" - if sample.remove_sample: - sample.loss_mask = [0] * sample.response_length - loss_masks.append(sample.loss_mask) - train_data["loss_masks"] = loss_masks - - # overwriting the raw reward - if samples[0].metadata and "raw_reward" in samples[0].metadata: - train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples] - - # For rollout buffer - if samples[0].metadata and "round_number" in samples[0].metadata: - train_data["round_number"] = [sample.metadata["round_number"] for sample in samples] - - # Add rollout log probabilities for off-policy correction - if samples[0].rollout_log_probs is not None: - train_data["rollout_log_probs"] = [sample.rollout_log_probs for sample in samples] - - if samples[0].rollout_routed_experts is not None: - train_data["rollout_routed_experts"] = [sample.rollout_routed_experts for sample in samples] - - if samples[0].train_metadata is not None: - train_data["metadata"] = [sample.train_metadata for sample in samples] - - if any(sample.multimodal_train_inputs is not None for sample in samples): - train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] - - if "teacher_log_probs" in samples[0].__dict__: - train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples] - - return train_data - - def set_train_parallel_config(self, config: dict): - self.train_parallel_config = config - - def 
_split_train_data_by_dp(self, data, dp_size): - """Split the train data by data parallel size.""" - rollout_data = {} - - if "prompt" in data: - rollout_data["prompt"] = data["prompt"] - - total_lengths = [len(t) for t in data["tokens"]] - data["total_lengths"] = total_lengths - - if self.args.balance_data: - partitions = get_seqlen_balanced_partitions(total_lengths, dp_size, equal_size=True) - else: - partitions = [range(i, len(total_lengths), dp_size) for i in range(dp_size)] - - rollout_data_refs = [] - - for i in range(dp_size): - rollout_data = {} - partition = partitions[i] - rollout_data["partition"] = partition - for key in [ - "tokens", - "multimodal_train_inputs", - "response_lengths", - "rewards", - "truncated", - "loss_masks", - "round_number", - "sample_indices", - "rollout_log_probs", - "rollout_routed_experts", - "prompt", - "teacher_log_probs", - ]: - if key not in data: - continue - val = [data[key][j] for j in partition] - rollout_data[key] = val - # keys that need to be splited at train side - for key in [ - "raw_reward", - "total_lengths", - ]: - if key not in data: - continue - rollout_data[key] = data[key] - # Pass dynamic global_batch_size to training side - if hasattr(self, "_dynamic_global_batch_size"): - rollout_data["dynamic_global_batch_size"] = self._dynamic_global_batch_size - rollout_data_refs.append(Box(ray.put(rollout_data))) - return rollout_data_refs - + def generate(): + pass # --------------------------------------------------------------------------- # Port allocation helpers @@ -739,22 +460,19 @@ def addr(): # Router + server bootstrap # --------------------------------------------------------------------------- -def _start_router(args: SGLangConfig, *, force_new: bool = False) -> tuple[str, int]: +def _start_router(args: SGLangConfig) -> tuple[str, int]: """Start sgl router or miles router and return (router_ip, router_port). 
def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup:
    """Start the rollout server group behind a freshly launched router.

    Args:
        sglang_cfg: sglang config dict; its ``sglang_router`` section is
            updated in place with the launched router's ip/port so downstream
            consumers can find it.
        cluster_cfg: cluster config dict (reads ``gpus_per_node``).
        pg: placement-group tuple ``(pg, reordered_bundle_indices, reordered_gpu_ids)``.

    Returns:
        The started ``ServerGroup`` (single group in this redesign).
    """
    # Single group for now; offsets are kept at 0 for future multi-group use.
    engine_offset = 0
    gpu_offset = 0

    router_ip, router_port = _start_router(sglang_cfg)
    sglang_cfg["sglang_router"]["sglang_router_ip"] = router_ip
    sglang_cfg["sglang_router"]["sglang_router_port"] = router_port

    all_init_handles: list = []
    port_cursors: dict[int, int] = {}

    gpus_per_engine = sglang_cfg["sglang_server"]["num_gpus_per_engine"]
    # An engine never spans nodes, so clamp its GPU count to one node's worth.
    num_gpu_per_engine_local = min(gpus_per_engine, cluster_cfg["gpus_per_node"])
    num_engines = sglang_cfg["sglang_server"]["num_gpus"] // num_gpu_per_engine_local
    needs_offload = sglang_cfg["sglang_server"]["needs_offload"]
    num_gpus_per_node = cluster_cfg["gpus_per_node"]
    model_path = sglang_cfg["sglang_cfg"]["model_path"]

    server_group = ServerGroup(
        pg=pg,
        all_engines=[None] * num_engines,
        num_gpus_per_engine=gpus_per_engine,
        num_gpus_per_node=num_gpus_per_node,
        num_new_engines=0,
        rank_offset=engine_offset,
        gpu_offset=gpu_offset,
        needs_offload=needs_offload,
        model_path=model_path,
        router_ip=router_ip,
        router_port=router_port,
    )

    # BUG FIX: previously called `group.start_engines(...)`, but `group` was
    # never defined (leftover loop variable from the multi-model version) and
    # raised NameError on every call.
    handles, port_cursors = server_group.start_engines(port_cursors)
    all_init_handles.extend(handles)

    if all_init_handles:
        # Block until every engine reports ready.
        ray.get(all_init_handles)

    return server_group
class AsyncLoopThread:
    """A daemon thread that owns a private asyncio event loop.

    Lets synchronous callers run coroutines without an ambient running loop.
    """

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        self._worker = threading.Thread(target=self._serve, daemon=True)
        self._worker.start()

    def _serve(self):
        # Bind the loop to this thread and spin it until process exit.
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def run(self, coro):
        """Submit *coro* to the background loop and block until it completes."""
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        return future.result()


# Lazily-created module-wide instance, shared by all callers of run().
async_loop = None


def get_async_loop():
    """Return the shared AsyncLoopThread, creating it on first use."""
    global async_loop
    if async_loop is None:
        async_loop = AsyncLoopThread()
    return async_loop


# NOTE(review): this shadows `run` imported from miles.utils.async_utils at the
# top of this module — confirm which implementation callers expect.
def run(coro):
    """Run a coroutine in the background event loop."""
    return get_async_loop().run(coro)
+ """ + routers = getattr(args, "sglang_model_routers", None) + if routers and model_name in routers: + ip, port = routers[model_name] + return f"http://{ip}:{port}{endpoint}" + return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}" + + +class GenerateState(metaclass=SingletonMeta): + """ + The global state for the generation process. + """ + + def __init__(self, args: Namespace) -> None: + # persistent state for the generation process + self.args = args + self.tokenizer = load_tokenizer( + args.hf_checkpoint, chat_template_path=args.chat_template_path, trust_remote_code=True + ) + self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + + self.semaphore = asyncio.Semaphore( + args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + ) + self.sampling_params: dict[str, Any] = dict( + temperature=args.rollout_temperature, + top_p=args.rollout_top_p, + top_k=args.rollout_top_k, + max_new_tokens=args.rollout_max_response_len, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_seed_base = args.rollout_seed + self.group_sampling_seeds = [sampling_seed_base + i for i in range(args.n_samples_per_prompt)] + + # dp rank balancing + self.dp_counts = [0] * (args.sglang_dp_size or 1) + self.dp_rank = 0 + + self.reset() + + @contextmanager + def dp_rank_context(self): + candidates = [i for i, count in enumerate(self.dp_counts) if count == min(self.dp_counts)] + dp_rank = int(np.random.choice(candidates)) + self.dp_counts[dp_rank] += 1 + self.dp_rank = dp_rank + try: + yield dp_rank + finally: + self.dp_counts[dp_rank] -= 1 + assert self.dp_counts[dp_rank] >= 0 + + def reset(self) -> None: + self.remaining_batch_size = 0 + self.pendings = set() + self.aborted = False + + def 
submit_generate_tasks(self, samples: list[list[Sample]]) -> None: + for group in samples: + self.pendings.add( + asyncio.create_task( + # submit a group of samples as a single task. + generate_and_rm_group( + self.args, + group, + sampling_params=self.sampling_params.copy(), + evaluation=False, + ) + ) + ) + self.remaining_batch_size += len(samples) + + +async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: + """Generate using traditional SGLang router with token-based workflow""" + if args.ci_test: + assert isinstance(sample.prompt, str) + + state = GenerateState(args) + url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + + assert ( + sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED + ), f"Sample status is {sample.status}" + + if state.processor and sample.multimodal_inputs and any(v is not None for v in sample.multimodal_inputs.values()): + processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) + prompt_ids = processor_output["input_ids"][0] + sample.multimodal_train_inputs = { + k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] + } or None + else: + prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) + + if len(sample.response) > 0: + sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) + + assert ( + sampling_params["max_new_tokens"] >= 0 + ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" + if sampling_params["max_new_tokens"] == 0: + sample.status = Sample.Status.TRUNCATED + return sample + + # Prepare payload for sglang server + payload = { + "sampling_params": sampling_params, + "return_logprob": True, + } + + if is_lora_enabled(args): + payload["lora_path"] = LORA_ADAPTER_NAME + + if args.use_rollout_routing_replay: + payload["return_routed_experts"] = True + + if sample.multimodal_inputs and 
sample.multimodal_inputs["images"]: + image_data = sample.multimodal_inputs["images"] + payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] + + # Use existing tokens for multi-turn or tokenize the new prompt + if len(sample.response) > 0: + payload["input_ids"] = sample.tokens + else: + payload["input_ids"] = prompt_ids + if not sample.tokens: # Initialize sample.tokens for the first turn + sample.tokens = prompt_ids + + output = await post(url, payload) + + if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: + from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree + + sample = await postprocess_sample_with_radix_tree(args, sample, output) + else: + if "output_token_logprobs" in output["meta_info"]: + new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + else: + new_response_tokens, new_response_log_probs = [], [] + + # Update sample with tokens directly - avoiding re-tokenization + sample.tokens = sample.tokens + new_response_tokens + sample.response_length += len(new_response_tokens) + sample.response += output["text"] + + # When partial rollout and masking off policy is enabled, update the loss mask + if sample.loss_mask is not None: + assert args.partial_rollout and args.mask_offpolicy_in_partial_rollout + sample.loss_mask += [1] * len(new_response_tokens) + + if sample.rollout_log_probs is None: + sample.rollout_log_probs = [] + sample.rollout_log_probs += new_response_log_probs + + if "routed_experts" in output["meta_info"]: + sample.rollout_routed_experts = np.frombuffer( + pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), + dtype=np.int32, + ).reshape( + len(sample.tokens) - 1, + args.num_layers, + args.moe_router_topk, + ) + + sample.update_from_meta_info(args, output["meta_info"]) + + 
return sample + + +async def generate_and_rm( + args: Namespace, + sample: Sample | list[Sample], + sampling_params: dict[str, Any], + evaluation: bool = False, +) -> Sample | list[Sample]: + # mask previous off-policy generation for partial rollout + if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + + # For samples with existing response, check if they're complete + if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: + assert sample.response is not None + if not args.group_rm: + assert sample.reward is not None + return sample + + state = GenerateState(args) + + # generate + async with state.semaphore: + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + + with state.dp_rank_context() as _: + # Check sample.generate_function_path for per-sample custom_generate_function_path (e.g., from eval dataset config) + custom_func_path = getattr(sample, "generate_function_path", None) or args.custom_generate_function_path + + if custom_func_path is not None: + custom_generate_func = load_function(custom_func_path) + # if signature has evaluation, pass evaluation + if "evaluation" in inspect.signature(custom_generate_func).parameters: + sample = await custom_generate_func(args, sample, sampling_params, evaluation=evaluation) + else: + sample = await custom_generate_func(args, sample, sampling_params) + else: + sample = await generate(args, sample, sampling_params) + + # for the rm that need the whole group, we will not do the rm here + if args.group_rm: + return sample + + # multi samples + if isinstance(sample, list): + samples = sample + if any([sample.status == Sample.Status.ABORTED for sample in samples]): + return samples + + # for multi agent system, the reward of some sample is calculated during generation. 
+        # Only score samples that don't already carry a reward (e.g. assigned
+        # mid-generation by a multi-agent environment); batched for efficiency.
+        samples_need_reward = [sample for sample in samples if sample.reward is None]
+        rewards = await batched_async_rm(args, samples_need_reward)
+        for sample, reward in zip(samples_need_reward, rewards, strict=False):
+            sample.reward = reward
+        return samples
+    else:
+        if sample.status == Sample.Status.ABORTED:
+            return sample
+        # for multi-turn environment, a reward could be assigned to the agent.
+        if sample.reward is None:
+            sample.reward = await async_rm(args, sample)
+
+    return sample
+
+
+async def generate_and_rm_group(
+    args: Namespace, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False
+) -> list[Sample]:
+    """Generate and reward-model-score one group of samples concurrently.
+
+    Each sample in the group gets its own copy of the sampling params; with
+    deterministic inference enabled, sample i in the group is pinned to the
+    i-th precomputed seed so reruns reproduce the same group.
+    """
+    state = GenerateState(args)
+
+    # If a global abort happened before we started, return the group untouched.
+    if state.aborted:
+        return group
+
+    tasks = []
+    for idx, sample in enumerate(group):
+        current_sampling_params = sampling_params.copy()
+        if getattr(args, "sglang_enable_deterministic_inference", False):
+            seed = state.group_sampling_seeds[idx]
+            current_sampling_params["sampling_seed"] = seed
+        tasks.append(
+            asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation))
+        )
+
+    group = await asyncio.gather(*tasks)
+
+    # for the rm that need the whole group, we will do the rm here
+    if not state.aborted and args.group_rm:
+        rewards = await batched_async_rm(args, group)
+        for sample, reward in zip(group, rewards, strict=False):
+            sample.reward = reward
+
+    return group
+
+
+async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]:
+    """Abort all in-flight generation on every worker behind the router.
+
+    Returns partial sample groups collected from pending tasks when
+    ``args.partial_rollout`` is enabled (empty list otherwise).
+    """
+    aborted_samples = []
+
+    state = GenerateState(args)
+    assert not state.aborted
+    state.aborted = True
+
+    # Worker-listing endpoint differs between sglang_router versions
+    # (<=0.2.1 and the miles router expose /list_workers; newer /workers).
+    if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router:
+        response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers")
+        urls = response["urls"]
+    else:
+        response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers")
+        urls = [worker["url"] for worker in response["workers"]]
+
+    logger.info(f"Abort request for {urls}")
+    
abort_tasks = [post(f"{url}/abort_request", {"abort_all": True}) for url in urls] + abort_results = await asyncio.gather(*abort_tasks, return_exceptions=True) + for url, result in zip(urls, abort_results, strict=False): + if isinstance(result, Exception): + logger.warning(f"Failed to abort worker at {url}: {result}") + + # make sure all the pending tasks are finished + count = 0 + while state.pendings: + done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) + + if not args.partial_rollout: + continue + + # for partial rollout, collect the partial samples into the data buffer + for task in done: + group = task.result() + for sample in group: + if sample.response and "start_rollout_id" not in sample.metadata: + sample.metadata["start_rollout_id"] = rollout_id + aborted_samples.append(group) + count += len(group) + + if args.partial_rollout: + logger.info(f"Collected {count} partial samples into the data buffer") + + return aborted_samples + + +async def generate_rollout_async( + args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + """An example to implement the generate_rollout function for an rule based rm rollout generation. 
+ + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + data_source: the data source to fetch + + Returns: + tuple[RolloutFnTrainOutput, list[list[Sample]]]: + - data: a list of groups of samples generated by the rollout, length equals `rollout_batch_size` + - aborted_samples: any partial groups collected during abort when partial_rollout is enabled + """ + assert args.rollout_global_dataset + + state = GenerateState(args) + + # instantiate data filters + dynamic_filter = ( + load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None + ) + + metric_gatherer = MetricGatherer() + + # target_data_size is the total number of valid samples to get + target_data_size = args.rollout_batch_size + + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") + while len(data) < target_data_size: + while state.remaining_batch_size < target_data_size: + # get samples from the buffer and submit the generation requests. 
+ samples = data_source(args.over_sampling_batch_size) + state.submit_generate_tasks(samples) + + # wait for the generation to finish + done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) + for task in done: + group: list[Sample] = task.result() + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) == args.n_samples_per_prompt + all_data.append(group) + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + state.remaining_batch_size -= 1 + continue + + # add the samples to the data + # NOTE: here we have not stored all the unused samples back to the data buffer. + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + + pbar.close() + sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] + logger.info( + f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", + ) + + # there are still some unfinished requests, abort them + aborted_samples = await abort(args, rollout_id) + + assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" + data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + + # reset the global state to prevent effects on the next rollout or eval. 
+ state.reset() + if args.rollout_sample_filter_path is not None: + filter_func = load_function(args.rollout_sample_filter_path) + filter_func(args, data) + + # There can be circumstances where users want to process all samples including filtered ones. + if args.rollout_all_samples_process_path is not None: + process_func = load_function(args.rollout_all_samples_process_path) + process_func(args, all_samples, data_source) + + return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples + + +async def generate_rollout_async_stream( + args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +): + """Streaming version of generate_rollout_async that yields each completed group immediately. + + Yields: + list[Sample]: a group of samples, yielded as soon as the group finishes generation and RM. + """ + assert args.rollout_global_dataset + + state = GenerateState(args) + + dynamic_filter = ( + load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None + ) + + metric_gatherer = MetricGatherer() + target_data_size = args.rollout_batch_size + + # Submit all tasks upfront so we can use as_completed + while state.remaining_batch_size < target_data_size: + samples = data_source(args.over_sampling_batch_size) + state.submit_generate_tasks(samples) + + # Snapshot the pending tasks and iterate by completion order + all_tasks = list(state.pendings) + state.pendings = set() + + data = [] + all_data = [] + do_print = True + pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation (stream)") + + for coro in asyncio.as_completed(all_tasks): + group: list[Sample] = await coro + + if do_print: + sample = group[0][0] if isinstance(group[0], list) else group[0] + logger.info( + f"First rollout sample: {[str(sample.prompt) + sample.response]}, " + f"label: {str(sample.label)[:100]}, reward: {sample.reward}", + ) + do_print = False + + assert len(group) 
== args.n_samples_per_prompt + all_data.append(group) + + dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) + if not dynamic_filter_output.keep: + metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) + continue + + if len(data) < target_data_size: + data.append(group) + pbar.update(args.n_samples_per_prompt) + yield group + + if len(data) >= target_data_size: + break + + pbar.close() + + # Abort remaining in-flight tasks + state.pendings = {t for t in all_tasks if not t.done()} + await abort(args, rollout_id) + + state.reset() + if args.rollout_sample_filter_path is not None: + filter_func = load_function(args.rollout_sample_filter_path) + filter_func(args, data) + + if args.rollout_all_samples_process_path is not None: + all_samples = sorted( + all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index + ) + process_func = load_function(args.rollout_all_samples_process_path) + process_func(args, all_samples, data_source) + + +EVAL_PROMPT_DATASET = {} + + +async def eval_rollout(args: Namespace, rollout_id: int) -> tuple[dict[str, dict[str, list[Any]]], list[list[Sample]]]: + assert not args.group_rm, "Group RM is not supported for eval rollout" + + coros = [] + for dataset_cfg in getattr(args, "eval_datasets", []) or []: + coros.append(eval_rollout_single_dataset(args, rollout_id, dataset_cfg)) + results_list = await asyncio.gather(*coros) + results = {} + for r in results_list: + results.update(r) + return RolloutFnEvalOutput(data=results), [] + + +async def eval_rollout_single_dataset( + args: Namespace, rollout_id: int, dataset_cfg: EvalDatasetConfig +) -> dict[str, dict[str, list[Any]]]: + """An example to implement the eval_rollout function for an rule based rm rollout generation. 
+ + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + dataset_cfg: configuration of the dataset + """ + assert not args.group_rm, "Group RM is not supported for eval rollout" + + global EVAL_PROMPT_DATASET + + cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template, args.chat_template_path) + if cache_key not in EVAL_PROMPT_DATASET: + tokenizer = load_tokenizer( + args.hf_checkpoint, chat_template_path=args.chat_template_path, trust_remote_code=True + ) + processor = load_processor(args.hf_checkpoint, trust_remote_code=True) + EVAL_PROMPT_DATASET[cache_key] = Dataset( + path=dataset_cfg.path, + tokenizer=tokenizer, + processor=processor, + max_length=args.eval_max_prompt_len, + prompt_key=dataset_cfg.input_key, + label_key=dataset_cfg.label_key, + multimodal_keys=args.multimodal_keys, + metadata_key=dataset_cfg.metadata_key, + tool_key=dataset_cfg.tool_key, + apply_chat_template=args.apply_chat_template, + apply_chat_template_kwargs=args.apply_chat_template_kwargs, + ) + dataset = EVAL_PROMPT_DATASET[cache_key] + + base_sampling_params = dict( + temperature=dataset_cfg.temperature, + top_p=dataset_cfg.top_p, + top_k=dataset_cfg.top_k, + max_new_tokens=dataset_cfg.max_response_len, + stop=args.rollout_stop, + stop_token_ids=args.rollout_stop_token_ids, + skip_special_tokens=args.rollout_skip_special_tokens, + no_stop_trim=True, + spaces_between_special_tokens=False, + ) + + tasks = [] + # do multiple samples for eval prompts + sample_index = 0 + for _i, prompt_sample in enumerate(dataset.samples): + for j in range(dataset_cfg.n_samples_per_eval_prompt): + # use the same prompt for multiple samples + sample = copy.deepcopy(prompt_sample) + sample.index = sample_index + sample_index += 1 + sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) + sample.generate_function_path = getattr(dataset_cfg, "custom_generate_function_path", None) + sampling_params = 
base_sampling_params + if getattr(args, "sglang_enable_deterministic_inference", False): + sampling_params = base_sampling_params.copy() + sampling_params["sampling_seed"] = args.rollout_seed + j + tasks.append( + asyncio.create_task( + generate_and_rm( + args, + sample, + sampling_params=sampling_params, + evaluation=True, + ) + ) + ) + + data = [] + do_print = True + pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) + for coro in asyncio.as_completed(tasks): + sample = await coro + if do_print: + logger.info( + "eval_rollout_single_dataset example data: " + f"{[str(sample.prompt) + sample.response]} " + f"reward={sample.reward}" + ) + do_print = False + if isinstance(sample, list): + data.extend(sample) + else: + data.append(sample) + pbar.update(1) + pbar.close() + + data.sort(key=lambda sample: sample.index) + + reward_key = args.eval_reward_key or args.reward_key + return { + dataset_cfg.name: { + "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], + "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], + "samples": data, + } + } + + +def generate_rollout( + args: Namespace, rollout_id: int, data_source: Any, evaluation: bool = False +) -> RolloutFnTrainOutput | RolloutFnEvalOutput: + """An example to implement the generate_rollout function for an rule based rm rollout generation. 
+ + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + data_buffer: the data buffer to store the generated samples + evaluation: bool, whether the rollout is for evaluation or not + + Returns: + list[list[Sample]]: a list of list of samples generated by the rollout + """ + assert args.rollout_global_dataset + if evaluation: + output, _ = run(eval_rollout(args, rollout_id)) + return output + + output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) + data_source.add_samples(aborted_samples) + return output From 0f1639d231ebf88bfaf0d6506619033853761089 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Thu, 9 Apr 2026 21:04:57 -0700 Subject: [PATCH 08/36] update --- .../generation/redesign/sglang_generation.py | 273 ++++++- .../generation/redesign/sglang_rollout.py | 704 ------------------ .../generation/redesign/sglang_worker.py | 2 +- nemo_rl/models/generation/redesign/utils.py | 4 +- 4 files changed, 274 insertions(+), 709 deletions(-) delete mode 100644 nemo_rl/models/generation/redesign/sglang_rollout.py diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 8239521d26..1f82c1603d 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -70,6 +70,60 @@ # pg_reordered_bundle_indices = cluster._get_sorted_bundle_indices() +class AsyncLoopThread: + def __init__(self): + self.loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._start_loop, daemon=True) + self._thread.start() + + def _start_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def run(self, coro): + # Schedule a coroutine onto the loop and block until it's done + return asyncio.run_coroutine_threadsafe(coro, self.loop).result() + +# Create one global instance +async_loop = None + +def get_async_loop(): + 
global async_loop + if async_loop is None: + async_loop = AsyncLoopThread() + return async_loop + +def run(coro): + """Run a coroutine in the background event loop.""" + return get_async_loop().run(coro) + + +async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_params, input_ids): + """Generate using traditional SGLang router with token-based workflow""" + url = f"http://{sglang_router_ip}:{sglang_router_port}/generate" + + # Prepare payload for sglang server + payload = { + "sampling_params": sampling_params, + "return_logprob": True, + "input_ids": input_ids, + } + + output = await post(url, payload) + + if "output_token_logprobs" in output["meta_info"]: + response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] + else: + response_tokens, response_log_probs = [], [] + + response_truncated = False + if response_truncated["meta_info"]["output_token_logprobs"] is not None and output["meta_info"]["output_token_logprobs"] == "length": + response_truncated = True + + return response_tokens, response_log_probs, response_truncated + + @dataclasses.dataclass class ServerGroup: """A group of homogeneous SGLang engines with the same configuration. @@ -361,9 +415,222 @@ def clear_updatable_num_new_engines(self): def check_weights(self, action: str): return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) - def generate(): - pass + def _merge_stop_strings(self, batch_stop_strings) -> list[list[str]]: + """Merge stop strings from config and batch. 
+ + Args: + batch_stop_strings: List of stop strings from batch (one per sample) + + Returns: + List of merged stop strings (one per sample) + """ + stop_set: set[str] = set() + + # Add stop strings from config + if self.sglang_cfg.get("stop_strings"): + stop_set.update(self.sglang_cfg["stop_strings"]) + + # Merge stop strings from batch + merged_stop_strings = [] + for sample_ss in batch_stop_strings: + sample_stop_set = stop_set.copy() + if sample_ss: + if isinstance(sample_ss, str): + sample_stop_set.add(sample_ss) + elif isinstance(sample_ss, list): + sample_stop_set.update(sample_ss) + + merged_stop_strings.append(list(sample_stop_set)) + + return merged_stop_strings + + def _build_sampling_params( + self, + *, + greedy: bool, + stop_strings: Optinoal[list[str]] = None, + input_len: Optional[int] = None, + ) -> dict[str, Any]: + """Build sampling parameters dictionary for SGLang API. + + Args: + greedy: Whether to use greedy decoding (temperature=0.0) + stop_strings: Merged stop strings (not used here, handled per sample) + max_new_tokens: Override max_new_tokens from config if provided + input_len: Input length for this sample (used for context_length adjustment) + + Returns: + Dictionary of sampling parameters compatible with SGLang API + """ + top_k_cfg = self.sglang_cfg.get("top_k") + top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1) + if top_k_val != -1: + sampling_params["top_k"] = top_k_val + temperature = 0.0 if greedy else self.sglang_cfg["temperature"] + + context_length = self.sglang_cfg["sglang_cfg"]["context_length"] + if context_length is not None: + max_new_tokens = min(self.sglang_cfg["max_new_tokens"], context_length - input_len) + else: + max_new_tokens = self.sglang_cfg["max_new_tokens"] + + if max_new_tokens < 0: + raise("context len is smaller than input len") + + # Build sampling params dict + sampling_params = { + "temperature": temperature, + "top_p": self.sglang_cfg.get("top_p", 1.0), + "max_new_tokens": 
max_new_tokens, + "no_stop_trim": True, + "spaces_between_special_tokens": False, + } + + stop_token_ids = self.sglang_cfg.get("stop_token_ids") + if stop_token_ids is not None: + sampling_params["stop_token_ids"] = stop_token_ids + + if stop_strings is not None and len(stop_strings) > 0: + sampling_params["stop"] = stop_strings + + return sampling_params + + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using vLLM generation. + + Args: + data: BatchedDataDict containing input_ids and input_lengths tensors + greedy: Whether to use greedy decoding instead of sampling + + Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs with proper padding + - logprobs: Log probabilities for tokens + - generation_lengths: Lengths of each response + - unpadded_sequence_lengths: Lengths of each input + generated sequence + """ + # Handle empty input case + if len(data["input_ids"]) == 0: + # Return empty BatchedDataDict with all required fields + return BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": torch.zeros((0, 0), dtype=torch.long), + "logprobs": torch.zeros((0, 0), dtype=torch.float), + "generation_lengths": torch.zeros(0, dtype=torch.long), + "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long), + "truncated": torch.zeros(0, dtype=torch.bool), + } + ) + + input_ids = data["input_ids"] + input_lengths = data["input_lengths"] + batch_stop_strings: list[list[str]] = data.get("stop_strings", []) + stop_strings = self._merge_stop_strings(batch_stop_strings) + + batch_size = len(input_lengths) + padded_input_length = input_ids.size(1) + + # verify inputs have correct padding + verify_right_padding(data, pad_value=self.sglang_cfg["_pad_token_id"]) + + # Original input length with padding + for i in range(batch_size): + input_len = input_lengths[i].item() + + valid_input_ids = input_ids[i, 
:input_len].tolist()
+
+            # Build sampling params for this sample (with context_length adjustment)
+            sample_sampling_params = self._build_sampling_params(
+                greedy=greedy,
+                stop_strings=stop_strings,
+                input_len=input_len,
+            )
+
+
+        # TODO:
+        # NOTE(review): incomplete — `outputs` below is never defined; the actual
+        # generation call (presumably fanning out generate_one_sample per sample
+        # and gathering) is still missing, so this method cannot run as-is.
+        # Process the outputs - but preserve the original input padding structure
+        output_ids_list = []
+        logprobs_list = []
+        generation_lengths = []
+        unpadded_sequence_lengths = []
+        truncated_list = []  # Track if response was truncated (hit max_tokens)
+        max_length = 0
+        for output in outputs:
+            max_length = max(max_length, len(output.outputs[0].token_ids))
+
+        for i, output in enumerate(outputs):
+            # NOTE(review): `output.outputs[0]` / `.token_ids` / `.logprobs` /
+            # `.finish_reason` are vLLM RequestOutput shapes, not SGLang HTTP
+            # responses — confirm which engine API this loop is meant to consume.
+            # Extract generated tokens
+            sequence_length = input_lengths[i]
+            generation = output.outputs[0]
+            generated_tokens = list(generation.token_ids)
+
+            # Calculate total sequence length (original input length + generated tokens)
+            total_length = padded_input_length + max_length
+
+            # Create a new tensor with the right size and fill with padding token
+            # NOTE(review): `self.cfg` here vs `self.sglang_cfg["_pad_token_id"]`
+            # used above — likely the same key on the wrong attribute; confirm.
+            full_output = torch.full(
+                (total_length,), self.cfg["_pad_token_id"], dtype=input_ids.dtype
+            )
+
+            # Copy original input (with padding) into the beginning
+            full_output[:sequence_length] = input_ids[i][:sequence_length]
+
+            # Add generated tokens after the original input
+            full_output[sequence_length : sequence_length + len(generated_tokens)] = (
+                torch.tensor(generated_tokens)
+            )
+
+            output_ids_list.append(full_output)
+            full_logprobs = torch.zeros(total_length, dtype=torch.float32)
+            if hasattr(generation, "logprobs") and generation.logprobs:
+                try:
+                    for idx, logprob_dict in enumerate(generation.logprobs):
+                        if logprob_dict:
+                            position = sequence_length + idx
+                            full_logprobs[position] = next(iter(logprob_dict.items()))[
+                                1
+                            ].logprob
+                except Exception:
+                    import traceback
+
+                    traceback.print_exc()
+
+            logprobs_list.append(full_logprobs)
+
+            response_length = sequence_length + len(generated_tokens)
+            generation_lengths.append(len(generated_tokens))
+            unpadded_sequence_lengths.append(response_length)
+
+            # Check if response was truncated (hit max_tokens length limit)
+            is_truncated = generation.finish_reason == "length"
+            truncated_list.append(is_truncated)
+
+            # NOTE(review): `self.llm.llm_engine` is a vLLM attribute copied from
+            # the vLLM worker; ServerGroup has no `self.llm` — this assert (and
+            # its vLLM-specific message) needs an SGLang equivalent.
+            assert response_length <= self.llm.llm_engine.model_config.max_model_len, (
+                f"response_length={response_length} > max_model_len={self.llm.llm_engine.model_config.max_model_len}, which should not happen. Please check this behavior in isolation by running `uv run --extra vllm tools/model_diagnostics/1.max_model_len_respected.py {self.llm.llm_engine.model_config.model}` and raise this issue with the vllm team."
+            )
+
+        # Create return data conforming to GenerationOutputSpec
+        output_ids = torch.stack(output_ids_list)
+        logprobs = torch.stack(logprobs_list)
+
+        return_data = BatchedDataDict[GenerationOutputSpec](
+            {
+                "output_ids": output_ids,
+                "logprobs": logprobs,
+                "generation_lengths": torch.tensor(
+                    generation_lengths, dtype=torch.long
+                ),
+                "unpadded_sequence_lengths": torch.tensor(
+                    unpadded_sequence_lengths, dtype=torch.long
+                ),
+                "truncated": torch.tensor(truncated_list, dtype=torch.bool),
+            }
+        )
+        return return_data
 # ---------------------------------------------------------------------------
 # Port allocation helpers
 # ---------------------------------------------------------------------------
@@ -537,7 +804,7 @@ def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup:
             router_port=router_port,
         )
 
-    handles, port_cursors = group.start_engines(port_cursors)
+    handles, port_cursors = server_group.start_engines(port_cursors)
     all_init_handles.extend(handles)
 
     if all_init_handles:
diff --git a/nemo_rl/models/generation/redesign/sglang_rollout.py b/nemo_rl/models/generation/redesign/sglang_rollout.py
deleted file mode 100644
index bdd5da38e4..0000000000
--- a/nemo_rl/models/generation/redesign/sglang_rollout.py
+++ /dev/null
@@ -1,704 +0,0 @@
-import asyncio
-import copy
-import inspect
-import logging
-
-from argparse import 
Namespace -from collections.abc import Callable -from contextlib import contextmanager -from typing import Any - -import numpy as np -import pybase64 -import sglang_router -from packaging.version import parse -from tqdm import tqdm - -from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled -from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput -from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter -from miles.utils.async_utils import run -from miles.utils.data import Dataset -from miles.utils.eval_config import EvalDatasetConfig -from miles.utils.http_utils import get, post -from miles.utils.misc import SingletonMeta, load_function -from miles.utils.processing_utils import encode_image_for_rollout_engine, load_processor, load_tokenizer -from miles.utils.types import Sample - -from .rm_hub import async_rm, batched_async_rm - -__all__ = ["generate_rollout", "get_model_url"] - -logger = logging.getLogger(__name__) - - -# Create a background event loop thread -class AsyncLoopThread: - def __init__(self): - self.loop = asyncio.new_event_loop() - self._thread = threading.Thread(target=self._start_loop, daemon=True) - self._thread.start() - - def _start_loop(self): - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def run(self, coro): - # Schedule a coroutine onto the loop and block until it's done - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() - - -# Create one global instance -async_loop = None - -def get_async_loop(): - global async_loop - if async_loop is None: - async_loop = AsyncLoopThread() - return async_loop - - -def run(coro): - """Run a coroutine in the background event loop.""" - return get_async_loop().run(coro) - -def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") -> str: - """Return the router URL for a named model. 
- - Use this in custom rollout functions to route requests to a specific - model when multiple models are deployed via ``--sglang-config``:: - - url = get_model_url(args, "ref", "/generate") - resp = await post(url, json=payload) - - Falls back to the default router if *model_name* is not found or - ``sglang_model_routers`` is not set. - """ - routers = getattr(args, "sglang_model_routers", None) - if routers and model_name in routers: - ip, port = routers[model_name] - return f"http://{ip}:{port}{endpoint}" - return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}" - - -class GenerateState(metaclass=SingletonMeta): - """ - The global state for the generation process. - """ - - def __init__(self, args: Namespace) -> None: - # persistent state for the generation process - self.args = args - self.tokenizer = load_tokenizer( - args.hf_checkpoint, chat_template_path=args.chat_template_path, trust_remote_code=True - ) - self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - - self.semaphore = asyncio.Semaphore( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine - ) - self.sampling_params: dict[str, Any] = dict( - temperature=args.rollout_temperature, - top_p=args.rollout_top_p, - top_k=args.rollout_top_k, - max_new_tokens=args.rollout_max_response_len, - stop=args.rollout_stop, - stop_token_ids=args.rollout_stop_token_ids, - skip_special_tokens=args.rollout_skip_special_tokens, - no_stop_trim=True, - spaces_between_special_tokens=False, - ) - - if getattr(args, "sglang_enable_deterministic_inference", False): - sampling_seed_base = args.rollout_seed - self.group_sampling_seeds = [sampling_seed_base + i for i in range(args.n_samples_per_prompt)] - - # dp rank balancing - self.dp_counts = [0] * (args.sglang_dp_size or 1) - self.dp_rank = 0 - - self.reset() - - @contextmanager - def dp_rank_context(self): - candidates = [i for i, count in enumerate(self.dp_counts) if count == 
min(self.dp_counts)] - dp_rank = int(np.random.choice(candidates)) - self.dp_counts[dp_rank] += 1 - self.dp_rank = dp_rank - try: - yield dp_rank - finally: - self.dp_counts[dp_rank] -= 1 - assert self.dp_counts[dp_rank] >= 0 - - def reset(self) -> None: - self.remaining_batch_size = 0 - self.pendings = set() - self.aborted = False - - def submit_generate_tasks(self, samples: list[list[Sample]]) -> None: - for group in samples: - self.pendings.add( - asyncio.create_task( - # submit a group of samples as a single task. - generate_and_rm_group( - self.args, - group, - sampling_params=self.sampling_params.copy(), - evaluation=False, - ) - ) - ) - self.remaining_batch_size += len(samples) - - -async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: - """Generate using traditional SGLang router with token-based workflow""" - if args.ci_test: - assert isinstance(sample.prompt, str) - - state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - - assert ( - sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED - ), f"Sample status is {sample.status}" - - if state.processor and sample.multimodal_inputs and any(v is not None for v in sample.multimodal_inputs.values()): - processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs) - prompt_ids = processor_output["input_ids"][0] - sample.multimodal_train_inputs = { - k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"] - } or None - else: - prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False) - - if len(sample.response) > 0: - sampling_params["max_new_tokens"] -= len(sample.tokens) - len(prompt_ids) - - assert ( - sampling_params["max_new_tokens"] >= 0 - ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0" - if sampling_params["max_new_tokens"] == 0: - sample.status = Sample.Status.TRUNCATED - return 
sample - - # Prepare payload for sglang server - payload = { - "sampling_params": sampling_params, - "return_logprob": True, - } - - if is_lora_enabled(args): - payload["lora_path"] = LORA_ADAPTER_NAME - - if args.use_rollout_routing_replay: - payload["return_routed_experts"] = True - - if sample.multimodal_inputs and sample.multimodal_inputs["images"]: - image_data = sample.multimodal_inputs["images"] - payload["image_data"] = [encode_image_for_rollout_engine(image) for image in image_data] - - # Use existing tokens for multi-turn or tokenize the new prompt - if len(sample.response) > 0: - payload["input_ids"] = sample.tokens - else: - payload["input_ids"] = prompt_ids - if not sample.tokens: # Initialize sample.tokens for the first turn - sample.tokens = prompt_ids - - output = await post(url, payload) - - if args.use_miles_router and "RadixTreeMiddleware" in args.miles_router_middleware_paths: - from miles.router.middleware_hub.radix_tree_middleware import postprocess_sample_with_radix_tree - - sample = await postprocess_sample_with_radix_tree(args, sample, output) - else: - if "output_token_logprobs" in output["meta_info"]: - new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - else: - new_response_tokens, new_response_log_probs = [], [] - - # Update sample with tokens directly - avoiding re-tokenization - sample.tokens = sample.tokens + new_response_tokens - sample.response_length += len(new_response_tokens) - sample.response += output["text"] - - # When partial rollout and masking off policy is enabled, update the loss mask - if sample.loss_mask is not None: - assert args.partial_rollout and args.mask_offpolicy_in_partial_rollout - sample.loss_mask += [1] * len(new_response_tokens) - - if sample.rollout_log_probs is None: - sample.rollout_log_probs = [] - sample.rollout_log_probs += new_response_log_probs - - if "routed_experts" 
in output["meta_info"]: - sample.rollout_routed_experts = np.frombuffer( - pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")), - dtype=np.int32, - ).reshape( - len(sample.tokens) - 1, - args.num_layers, - args.moe_router_topk, - ) - - sample.update_from_meta_info(args, output["meta_info"]) - - return sample - - -async def generate_and_rm( - args: Namespace, - sample: Sample | list[Sample], - sampling_params: dict[str, Any], - evaluation: bool = False, -) -> Sample | list[Sample]: - # mask previous off-policy generation for partial rollout - if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: - sample.loss_mask = [0] * sample.response_length - - # For samples with existing response, check if they're complete - if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED: - assert sample.response is not None - if not args.group_rm: - assert sample.reward is not None - return sample - - state = GenerateState(args) - - # generate - async with state.semaphore: - if state.aborted: - sample.status = Sample.Status.ABORTED - return sample - - with state.dp_rank_context() as _: - # Check sample.generate_function_path for per-sample custom_generate_function_path (e.g., from eval dataset config) - custom_func_path = getattr(sample, "generate_function_path", None) or args.custom_generate_function_path - - if custom_func_path is not None: - custom_generate_func = load_function(custom_func_path) - # if signature has evaluation, pass evaluation - if "evaluation" in inspect.signature(custom_generate_func).parameters: - sample = await custom_generate_func(args, sample, sampling_params, evaluation=evaluation) - else: - sample = await custom_generate_func(args, sample, sampling_params) - else: - sample = await generate(args, sample, sampling_params) - - # for the rm that need the whole group, we will not do the rm here - if args.group_rm: - return sample - - # multi samples - if 
isinstance(sample, list): - samples = sample - if any([sample.status == Sample.Status.ABORTED for sample in samples]): - return samples - - # for multi agent system, the reward of some sample is calculated during generation. - samples_need_reward = [sample for sample in samples if sample.reward is None] - rewards = await batched_async_rm(args, samples_need_reward) - for sample, reward in zip(samples_need_reward, rewards, strict=False): - sample.reward = reward - return samples - else: - if sample.status == Sample.Status.ABORTED: - return sample - # for multi-turn environment, a reward could be assigned to the agent. - if sample.reward is None: - sample.reward = await async_rm(args, sample) - - return sample - - -async def generate_and_rm_group( - args: Namespace, group: list[Sample], sampling_params: dict[str, Any], evaluation: bool = False -) -> list[Sample]: - state = GenerateState(args) - - if state.aborted: - return group - - tasks = [] - for idx, sample in enumerate(group): - current_sampling_params = sampling_params.copy() - if getattr(args, "sglang_enable_deterministic_inference", False): - seed = state.group_sampling_seeds[idx] - current_sampling_params["sampling_seed"] = seed - tasks.append( - asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation)) - ) - - group = await asyncio.gather(*tasks) - - # for the rm that need the whole group, we will do the rm here - if not state.aborted and args.group_rm: - rewards = await batched_async_rm(args, group) - for sample, reward in zip(group, rewards, strict=False): - sample.reward = reward - - return group - - -async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: - aborted_samples = [] - - state = GenerateState(args) - assert not state.aborted - state.aborted = True - - if parse(sglang_router.__version__) <= parse("0.2.1") or args.use_miles_router: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") - urls = 
response["urls"] - else: - response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/workers") - urls = [worker["url"] for worker in response["workers"]] - - logger.info(f"Abort request for {urls}") - abort_tasks = [post(f"{url}/abort_request", {"abort_all": True}) for url in urls] - abort_results = await asyncio.gather(*abort_tasks, return_exceptions=True) - for url, result in zip(urls, abort_results, strict=False): - if isinstance(result, Exception): - logger.warning(f"Failed to abort worker at {url}: {result}") - - # make sure all the pending tasks are finished - count = 0 - while state.pendings: - done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) - - if not args.partial_rollout: - continue - - # for partial rollout, collect the partial samples into the data buffer - for task in done: - group = task.result() - for sample in group: - if sample.response and "start_rollout_id" not in sample.metadata: - sample.metadata["start_rollout_id"] = rollout_id - aborted_samples.append(group) - count += len(group) - - if args.partial_rollout: - logger.info(f"Collected {count} partial samples into the data buffer") - - return aborted_samples - - -async def generate_rollout_async( - args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] -) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: - """An example to implement the generate_rollout function for an rule based rm rollout generation. 
- - Args: - args: the whole args - rollout_id: int, the id of the rollout, used for deterministic data generation - data_source: the data source to fetch - - Returns: - tuple[RolloutFnTrainOutput, list[list[Sample]]]: - - data: a list of groups of samples generated by the rollout, length equals `rollout_batch_size` - - aborted_samples: any partial groups collected during abort when partial_rollout is enabled - """ - assert args.rollout_global_dataset - - state = GenerateState(args) - - # instantiate data filters - dynamic_filter = ( - load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None - ) - - metric_gatherer = MetricGatherer() - - # target_data_size is the total number of valid samples to get - target_data_size = args.rollout_batch_size - - data = [] - all_data = [] - do_print = True - pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation") - while len(data) < target_data_size: - while state.remaining_batch_size < target_data_size: - # get samples from the buffer and submit the generation requests. 
- samples = data_source(args.over_sampling_batch_size) - state.submit_generate_tasks(samples) - - # wait for the generation to finish - done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED) - for task in done: - group: list[Sample] = task.result() - - if do_print: - sample = group[0][0] if isinstance(group[0], list) else group[0] - logger.info( - f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", - ) - do_print = False - - assert len(group) == args.n_samples_per_prompt - all_data.append(group) - dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) - if not dynamic_filter_output.keep: - metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) - state.remaining_batch_size -= 1 - continue - - # add the samples to the data - # NOTE: here we have not stored all the unused samples back to the data buffer. - if len(data) < target_data_size: - data.append(group) - pbar.update(args.n_samples_per_prompt) - - pbar.close() - sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0] - logger.info( - f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {str(sample.label)[:100]}, reward: {sample.reward}", - ) - - # there are still some unfinished requests, abort them - aborted_samples = await abort(args, rollout_id) - - assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}" - data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index) - all_samples = sorted( - all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index - ) - - # reset the global state to prevent effects on the next rollout or eval. 
- state.reset() - if args.rollout_sample_filter_path is not None: - filter_func = load_function(args.rollout_sample_filter_path) - filter_func(args, data) - - # There can be circumstances where users want to process all samples including filtered ones. - if args.rollout_all_samples_process_path is not None: - process_func = load_function(args.rollout_all_samples_process_path) - process_func(args, all_samples, data_source) - - return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples - - -async def generate_rollout_async_stream( - args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] -): - """Streaming version of generate_rollout_async that yields each completed group immediately. - - Yields: - list[Sample]: a group of samples, yielded as soon as the group finishes generation and RM. - """ - assert args.rollout_global_dataset - - state = GenerateState(args) - - dynamic_filter = ( - load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None - ) - - metric_gatherer = MetricGatherer() - target_data_size = args.rollout_batch_size - - # Submit all tasks upfront so we can use as_completed - while state.remaining_batch_size < target_data_size: - samples = data_source(args.over_sampling_batch_size) - state.submit_generate_tasks(samples) - - # Snapshot the pending tasks and iterate by completion order - all_tasks = list(state.pendings) - state.pendings = set() - - data = [] - all_data = [] - do_print = True - pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation (stream)") - - for coro in asyncio.as_completed(all_tasks): - group: list[Sample] = await coro - - if do_print: - sample = group[0][0] if isinstance(group[0], list) else group[0] - logger.info( - f"First rollout sample: {[str(sample.prompt) + sample.response]}, " - f"label: {str(sample.label)[:100]}, reward: {sample.reward}", - ) - do_print = False - - assert len(group) 
== args.n_samples_per_prompt - all_data.append(group) - - dynamic_filter_output = call_dynamic_filter(dynamic_filter, args, group) - if not dynamic_filter_output.keep: - metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason) - continue - - if len(data) < target_data_size: - data.append(group) - pbar.update(args.n_samples_per_prompt) - yield group - - if len(data) >= target_data_size: - break - - pbar.close() - - # Abort remaining in-flight tasks - state.pendings = {t for t in all_tasks if not t.done()} - await abort(args, rollout_id) - - state.reset() - if args.rollout_sample_filter_path is not None: - filter_func = load_function(args.rollout_sample_filter_path) - filter_func(args, data) - - if args.rollout_all_samples_process_path is not None: - all_samples = sorted( - all_data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index - ) - process_func = load_function(args.rollout_all_samples_process_path) - process_func(args, all_samples, data_source) - - -EVAL_PROMPT_DATASET = {} - - -async def eval_rollout(args: Namespace, rollout_id: int) -> tuple[dict[str, dict[str, list[Any]]], list[list[Sample]]]: - assert not args.group_rm, "Group RM is not supported for eval rollout" - - coros = [] - for dataset_cfg in getattr(args, "eval_datasets", []) or []: - coros.append(eval_rollout_single_dataset(args, rollout_id, dataset_cfg)) - results_list = await asyncio.gather(*coros) - results = {} - for r in results_list: - results.update(r) - return RolloutFnEvalOutput(data=results), [] - - -async def eval_rollout_single_dataset( - args: Namespace, rollout_id: int, dataset_cfg: EvalDatasetConfig -) -> dict[str, dict[str, list[Any]]]: - """An example to implement the eval_rollout function for an rule based rm rollout generation. 
- - Args: - args: the whole args - rollout_id: int, the id of the rollout, used for deterministic data generation - dataset_cfg: configuration of the dataset - """ - assert not args.group_rm, "Group RM is not supported for eval rollout" - - global EVAL_PROMPT_DATASET - - cache_key = dataset_cfg.cache_key + (args.hf_checkpoint, args.apply_chat_template, args.chat_template_path) - if cache_key not in EVAL_PROMPT_DATASET: - tokenizer = load_tokenizer( - args.hf_checkpoint, chat_template_path=args.chat_template_path, trust_remote_code=True - ) - processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - EVAL_PROMPT_DATASET[cache_key] = Dataset( - path=dataset_cfg.path, - tokenizer=tokenizer, - processor=processor, - max_length=args.eval_max_prompt_len, - prompt_key=dataset_cfg.input_key, - label_key=dataset_cfg.label_key, - multimodal_keys=args.multimodal_keys, - metadata_key=dataset_cfg.metadata_key, - tool_key=dataset_cfg.tool_key, - apply_chat_template=args.apply_chat_template, - apply_chat_template_kwargs=args.apply_chat_template_kwargs, - ) - dataset = EVAL_PROMPT_DATASET[cache_key] - - base_sampling_params = dict( - temperature=dataset_cfg.temperature, - top_p=dataset_cfg.top_p, - top_k=dataset_cfg.top_k, - max_new_tokens=dataset_cfg.max_response_len, - stop=args.rollout_stop, - stop_token_ids=args.rollout_stop_token_ids, - skip_special_tokens=args.rollout_skip_special_tokens, - no_stop_trim=True, - spaces_between_special_tokens=False, - ) - - tasks = [] - # do multiple samples for eval prompts - sample_index = 0 - for _i, prompt_sample in enumerate(dataset.samples): - for j in range(dataset_cfg.n_samples_per_eval_prompt): - # use the same prompt for multiple samples - sample = copy.deepcopy(prompt_sample) - sample.index = sample_index - sample_index += 1 - sample.metadata = dataset_cfg.inject_metadata(getattr(sample, "metadata", None)) - sample.generate_function_path = getattr(dataset_cfg, "custom_generate_function_path", None) - sampling_params = 
base_sampling_params - if getattr(args, "sglang_enable_deterministic_inference", False): - sampling_params = base_sampling_params.copy() - sampling_params["sampling_seed"] = args.rollout_seed + j - tasks.append( - asyncio.create_task( - generate_and_rm( - args, - sample, - sampling_params=sampling_params, - evaluation=True, - ) - ) - ) - - data = [] - do_print = True - pbar = tqdm(total=len(tasks), desc=f"Eval {dataset_cfg.name}", disable=not do_print) - for coro in asyncio.as_completed(tasks): - sample = await coro - if do_print: - logger.info( - "eval_rollout_single_dataset example data: " - f"{[str(sample.prompt) + sample.response]} " - f"reward={sample.reward}" - ) - do_print = False - if isinstance(sample, list): - data.extend(sample) - else: - data.append(sample) - pbar.update(1) - pbar.close() - - data.sort(key=lambda sample: sample.index) - - reward_key = args.eval_reward_key or args.reward_key - return { - dataset_cfg.name: { - "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], - "truncated": [sample.status == Sample.Status.TRUNCATED for sample in data], - "samples": data, - } - } - - -def generate_rollout( - args: Namespace, rollout_id: int, data_source: Any, evaluation: bool = False -) -> RolloutFnTrainOutput | RolloutFnEvalOutput: - """An example to implement the generate_rollout function for an rule based rm rollout generation. 
- - Args: - args: the whole args - rollout_id: int, the id of the rollout, used for deterministic data generation - data_buffer: the data buffer to store the generated samples - evaluation: bool, whether the rollout is for evaluation or not - - Returns: - list[list[Sample]]: a list of list of samples generated by the rollout - """ - assert args.rollout_global_dataset - if evaluation: - output, _ = run(eval_rollout(args, rollout_id)) - return output - - output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) - data_source.add_samples(aborted_samples) - return output diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 7ccd708896..269e9539ef 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -500,7 +500,7 @@ def _compute_server_args( for key in [ "dtype", "kv_cache_dtype", - "context_length", + " ", "max_running_requests", "chunked_prefill_size", "max_prefill_tokens", diff --git a/nemo_rl/models/generation/redesign/utils.py b/nemo_rl/models/generation/redesign/utils.py index 64457f64fd..2a781f5761 100644 --- a/nemo_rl/models/generation/redesign/utils.py +++ b/nemo_rl/models/generation/redesign/utils.py @@ -1,4 +1,6 @@ + + def find_available_port(base_port: int): port = base_port + random.randint(100, 1000) while True: @@ -211,7 +213,7 @@ def init_http_client(args: SglangSpecificArgs): _distributed_post_enabled = True -def _init_ray_distributed_post(args): +def _init_ray_distributed_post(args: SglangSpecificArgs): """Initialize one or more Ray async actors per node for HTTP POST. Uses NodeAffinitySchedulingStrategy to place actors on distinct nodes. 
From 1653e72a9065918155898785929c2b91c05289f2 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 12:42:28 -0700 Subject: [PATCH 09/36] update --- .../generation/redesign/sglang_generation.py | 329 ++++++++++++++---- 1 file changed, 253 insertions(+), 76 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 1f82c1603d..00173d51e9 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -1,12 +1,14 @@ +import asyncio import dataclasses import itertools import logging import multiprocessing import os import random +import threading import time from pathlib import Path -from typing import Any +from typing import Any, AsyncGenerator import numpy as np import ray @@ -98,7 +100,7 @@ def run(coro): return get_async_loop().run(coro) -async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_params, input_ids): +async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_params, input_ids, index: int): """Generate using traditional SGLang router with token-based workflow""" url = f"http://{sglang_router_ip}:{sglang_router_port}/generate" @@ -121,7 +123,7 @@ async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_par if response_truncated["meta_info"]["output_token_logprobs"] is not None and output["meta_info"]["output_token_logprobs"] == "length": response_truncated = True - return response_tokens, response_log_probs, response_truncated + return index, response_tokens, response_log_probs, response_truncated @dataclasses.dataclass @@ -448,6 +450,7 @@ def _build_sampling_params( self, *, greedy: bool, + max_new_tokens: int, stop_strings: Optinoal[list[str]] = None, input_len: Optional[int] = None, ) -> dict[str, Any]: @@ -468,15 +471,6 @@ def _build_sampling_params( sampling_params["top_k"] = top_k_val temperature = 0.0 if greedy else 
self.sglang_cfg["temperature"] - context_length = self.sglang_cfg["sglang_cfg"]["context_length"] - if context_length is not None: - max_new_tokens = min(self.sglang_cfg["max_new_tokens"], context_length - input_len) - else: - max_new_tokens = self.sglang_cfg["max_new_tokens"] - - if max_new_tokens < 0: - raise("context len is smaller than input len") - # Build sampling params dict sampling_params = { "temperature": temperature, @@ -532,105 +526,288 @@ def generate( batch_size = len(input_lengths) padded_input_length = input_ids.size(1) + context_length = self.sglang_cfg["sglang_cfg"]["context_length"] # verify inputs have correct padding verify_right_padding(data, pad_value=self.sglang_cfg["_pad_token_id"]) - # Original input length with padding + # Build per-sample requests (each sample gets its own sampling params because + # max_new_tokens is adjusted against the per-sample input length). + sample_requests: list[tuple[int, dict[str, Any], list[int]]] = [] + skip_results = [] for i in range(batch_size): input_len = input_lengths[i].item() - valid_input_ids = input_ids[i, :input_len].tolist() - # Build sampling params for this sample (with context_length adjustment) - sample_sampling_params = self._build_sampling_params( - greedy=greedy, - stop_strings=stop_strings, - input_len=input_len, - ) + if context_length is not None: + max_new_tokens = min(self.sglang_cfg["max_new_tokens"], context_length - input_len) + else: + max_new_tokens = self.sglang_cfg["max_new_tokens"] + max_new_tokens = max(0, max_new_tokens) + + if max_new_tokens == 0: + skip_results.append(i) + else: + sample_sampling_params = self._build_sampling_params( + greedy=greedy, + max_new_tokens=max_new_tokens, + stop_strings=stop_strings[i] if i < len(stop_strings) else None, + input_len=input_len, + ) - - # TODO: - # Process the outputs - but preserve the original input padding structure - output_ids_list = [] - logprobs_list = [] - generation_lengths = [] - unpadded_sequence_lengths = [] - 
truncated_list = [] # Track if response was truncated (hit max_tokens) - max_length = 0 - for output in outputs: - max_length = max(max_length, len(output.outputs[0].token_ids)) + sample_requests.append((i, sample_sampling_params, valid_input_ids)) - for i, output in enumerate(outputs): - # Extract generated tokens - sequence_length = input_lengths[i] - generation = output.outputs[0] - generated_tokens = list(generation.token_ids) + # Dispatch concurrently to the SGLang router with bounded concurrency. + # Max concurrency = per-engine concurrency * number of engines. + sglang_server_cfg = self.sglang_cfg["sglang_server"] + max_concurrency = ( + sglang_server_cfg["sglang_server_concurrency"] + * sglang_server_cfg["num_gpus"] + // sglang_server_cfg["num_gpus_per_engine"] + ) - # Calculate total sequence length (original input length + generated tokens) - total_length = padded_input_length + max_length + router_ip = self.sglang_cfg["sglang_router"]["sglang_router_ip"] + router_port = self.sglang_cfg["sglang_router"]["sglang_router_port"] - # Create a new tensor with the right size and fill with padding token - full_output = torch.full( - (total_length,), self.cfg["_pad_token_id"], dtype=input_ids.dtype - ) + semaphore = asyncio.Semaphore(max_concurrency) - # Copy original input (with padding) into the beginning - full_output[:sequence_length] = input_ids[i][:sequence_length] + async def _bounded_generate_one_sample( + idx: int, sp: dict[str, Any], ids: list[int] + ): + async with semaphore: + return await generate_one_sample( + router_ip, router_port, sp, ids, idx + ) - # Add generated tokens after the original input - full_output[sequence_length : sequence_length + len(generated_tokens)] = ( - torch.tensor(generated_tokens) + async def _dispatch_all(): + return await asyncio.gather( + *( + _bounded_generate_one_sample(idx, sp, ids) + for idx, sp, ids in sample_requests + ) ) + + router_results = run(_dispatch_all()) + + # Process the outputs - preserve the original 
input padding structure. + pad_token_id = self.sglang_cfg["_pad_token_id"] + output_ids_list: list[torch.Tensor] = [] + logprobs_list: list[torch.Tensor] = [] + generation_lengths_list: list[int] = [] + unpadded_sequence_lengths_list: list[int] = [] + truncated_list: list[bool] = [] + + # First pass: compute max unpadded (input + generation) length across the batch + # so every sample can be right-padded to the same width. + max_length = 0 + for i, (_, new_tokens, _, _) in enumerate(router_results): + input_len = input_lengths[i].item() + max_length = max(max_length, input_len + len(new_tokens)) + total_length = max(max_length, padded_input_length) + if len(skip_results): + total_length = max(total_length, context_length) - output_ids_list.append(full_output) + for i in range(batch_size): + input_len = input_lengths[i].item() + full_output = torch.full( + (total_length,), pad_token_id, dtype=input_ids.dtype + ) full_logprobs = torch.zeros(total_length, dtype=torch.float32) - if hasattr(generation, "logprobs") and generation.logprobs: - try: - for idx, logprob_dict in enumerate(generation.logprobs): - if logprob_dict: - position = sequence_length + idx - full_logprobs[position] = next(iter(logprob_dict.items()))[ - 1 - ].logprob - except Exception: - import traceback - - traceback.print_exc() + if i in skip_results: - logprobs_list.append(full_logprobs) + else: - response_length = sequence_length + len(generated_tokens) - generation_lengths.append(len(generated_tokens)) - unpadded_sequence_lengths.append(response_length) - # Check if response was truncated (hit max_tokens length limit) - is_truncated = generation.finish_reason == "length" - truncated_list.append(is_truncated) + for i, (_, new_tokens, new_logprobs, is_truncated) in enumerate(router_results): + input_len = input_lengths[i].item() + generation_length = len(new_tokens) + unpadded_length = input_len + generation_length - assert response_length <= self.llm.llm_engine.model_config.max_model_len, ( - 
f"response_length={response_length} > max_model_len={self.llm.llm_engine.model_config.max_model_len}, which should not happen. Please check this behavior in isolation by running `uv run --extra vllm tools/model_diagnostics/1.max_model_len_respected.py {self.llm.llm_engine.model_config.model}` and raise this issue with the vllm team." + # Build full output: [input tokens | generated tokens | pad] + full_output = torch.full( + (total_length,), pad_token_id, dtype=input_ids.dtype ) + full_output[:input_len] = input_ids[i][:input_len] + if new_tokens: + full_output[input_len : input_len + generation_length] = torch.tensor( + new_tokens, dtype=input_ids.dtype + ) + + # Logprobs: zeros for input tokens, actual logprobs at generated positions. + full_logprobs = torch.zeros(total_length, dtype=torch.float32) + if new_logprobs: + for idx, logprob in enumerate(new_logprobs): + full_logprobs[input_len + idx] = logprob - # Create return data conforming to GenerationOutputSpec - output_ids = torch.stack(output_ids_list) - logprobs = torch.stack(logprobs_list) + output_ids_list.append(full_output) + logprobs_list.append(full_logprobs) + generation_lengths_list.append(generation_length) + unpadded_sequence_lengths_list.append(unpadded_length) + truncated_list.append(bool(is_truncated)) return_data = BatchedDataDict[GenerationOutputSpec]( { - "output_ids": output_ids, - "logprobs": logprobs, + "output_ids": torch.stack(output_ids_list), + "logprobs": torch.stack(logprobs_list), "generation_lengths": torch.tensor( - generation_lengths, dtype=torch.long + generation_lengths_list, dtype=torch.long ), "unpadded_sequence_lengths": torch.tensor( - unpadded_sequence_lengths, dtype=torch.long + unpadded_sequence_lengths_list, dtype=torch.long ), "truncated": torch.tensor(truncated_list, dtype=torch.bool), } ) return return_data + + async def generate_async( + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + ) -> AsyncGenerator[tuple[int, 
BatchedDataDict[GenerationOutputSpec]], None]: + """Generate a single sample using SGLang, yielding the result when ready. + + This mirrors ``VllmGenerationWorker.generate_async``: it is restricted to + single-sample batches (``batch_size == 1``), so the surrounding + ``asyncio.create_task(...) / asyncio.as_completed(...)`` fan-out used in + the vLLM version is unnecessary here — we simply ``await`` the single + ``generate_one_sample`` call directly and yield its output. + + Args: + data: BatchedDataDict with input_ids and input_lengths (batch_size must be 1) + greedy: Whether to use greedy decoding instead of sampling + + Yields: + Tuple of (original_index, BatchedDataDict conforming to GenerationOutputSpec) + """ + # Handle empty input case + if len(data["input_ids"]) == 0: + return + + verify_right_padding(data, pad_value=self.sglang_cfg["_pad_token_id"]) + + input_ids_batch = data["input_ids"] + input_lengths_batch = data["input_lengths"] + batch_size = input_ids_batch.shape[0] + + # Restrict to single-sample batches, matching the vLLM async contract. + assert batch_size == 1, ( + f"generate_async is restricted to handle only single samples, " + f"but received batch_size={batch_size}. Please handle batching outside this method." + ) + + sample_idx = 0 + input_len = input_lengths_batch[sample_idx].item() + valid_input_ids = input_ids_batch[sample_idx, :input_len].tolist() + + # Merge stop strings for this single sample. + batch_stop_strings: list[list[str]] = data.get("stop_strings", []) + stop_strings = self._merge_stop_strings(batch_stop_strings) + per_sample_stop_strings = ( + stop_strings[sample_idx] if sample_idx < len(stop_strings) else None + ) + + context_length = self.sglang_cfg["sglang_cfg"].get("context_length") + remaining_ctx = ( + (context_length - input_len) if context_length is not None else None + ) + + # Short-circuit when the prompt already fills the context window. 
+ if remaining_ctx is not None and remaining_ctx <= 0: + device = input_ids_batch.device + output_ids_single_item_batched = input_ids_batch[ + sample_idx, :input_len + ].unsqueeze(0) + logprobs_single_item = torch.zeros( + (1, input_len), dtype=torch.float32, device=device + ) + empty_result = BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": output_ids_single_item_batched, + "logprobs": logprobs_single_item, + "generation_lengths": torch.tensor( + [0], dtype=torch.long, device=device + ), + "unpadded_sequence_lengths": torch.tensor( + [input_len], dtype=torch.long, device=device + ), + "truncated": torch.tensor( + [False], dtype=torch.bool, device=device + ), + } + ) + yield (sample_idx, empty_result) + return + + sampling_params = self._build_sampling_params( + greedy=greedy, + stop_strings=per_sample_stop_strings, + input_len=input_len, + ) + + router_ip = self.sglang_cfg["sglang_router"]["sglang_router_ip"] + router_port = self.sglang_cfg["sglang_router"]["sglang_router_port"] + + # batch_size == 1, so no task fan-out / as_completed is needed. Just + # await the single coroutine directly. + _, new_tokens, new_logprobs, is_truncated = await generate_one_sample( + router_ip, + router_port, + sampling_params, + valid_input_ids, + sample_idx, + ) + + # Build the single-sample output tensor: [input | generated]. 
+ pad_token_id = self.sglang_cfg["_pad_token_id"] + original_input_ids_single_row = input_ids_batch[sample_idx] + device = original_input_ids_single_row.device + dtype = original_input_ids_single_row.dtype + + num_generated_tokens = len(new_tokens) + final_output_tensor_len = input_len + num_generated_tokens + + output_ids_single_item = torch.full( + (final_output_tensor_len,), pad_token_id, dtype=dtype, device=device + ) + output_ids_single_item[:input_len] = original_input_ids_single_row[:input_len] + if new_tokens: + output_ids_single_item[input_len:final_output_tensor_len] = torch.tensor( + new_tokens, dtype=dtype, device=device + ) + output_ids_single_item_batched = output_ids_single_item.unsqueeze(0) + + # Logprobs: zeros for input tokens, raw floats at generated positions. + logprobs_single_item = torch.zeros( + (1, final_output_tensor_len), dtype=torch.float32, device=device + ) + if new_logprobs: + for idx, logprob in enumerate(new_logprobs): + logprobs_single_item[0, input_len + idx] = logprob + + generation_lengths_tensor = torch.tensor( + [num_generated_tokens], dtype=torch.long, device=device + ) + unpadded_sequence_lengths_tensor = torch.tensor( + [final_output_tensor_len], dtype=torch.long, device=device + ) + truncated_tensor = torch.tensor( + [bool(is_truncated)], dtype=torch.bool, device=device + ) + + result_batch = BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": output_ids_single_item_batched, + "logprobs": logprobs_single_item, + "generation_lengths": generation_lengths_tensor, + "unpadded_sequence_lengths": unpadded_sequence_lengths_tensor, + "truncated": truncated_tensor, + } + ) + + yield (sample_idx, result_batch) + # --------------------------------------------------------------------------- # Port allocation helpers # --------------------------------------------------------------------------- From 2da9c4b09a0731bc2c0e3fd050eb4e2a0681d95d Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 13:10:30 -0700 Subject: 
[PATCH 10/36] update --- .../generation/redesign/sglang_generation.py | 240 +++++++++--------- 1 file changed, 122 insertions(+), 118 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 00173d51e9..a151a40b49 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -451,28 +451,26 @@ def _build_sampling_params( *, greedy: bool, max_new_tokens: int, - stop_strings: Optinoal[list[str]] = None, - input_len: Optional[int] = None, + stop_strings: list[str] | None = None, ) -> dict[str, Any]: """Build sampling parameters dictionary for SGLang API. Args: greedy: Whether to use greedy decoding (temperature=0.0) - stop_strings: Merged stop strings (not used here, handled per sample) - max_new_tokens: Override max_new_tokens from config if provided - input_len: Input length for this sample (used for context_length adjustment) + max_new_tokens: Max new tokens for this sample (already clamped by caller + against ``context_length - input_length``). + stop_strings: Merged stop strings for this sample. Returns: - Dictionary of sampling parameters compatible with SGLang API + Dictionary of sampling parameters compatible with SGLang API. """ + temperature = 0.0 if greedy else self.sglang_cfg["temperature"] top_k_cfg = self.sglang_cfg.get("top_k") top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1) - if top_k_val != -1: - sampling_params["top_k"] = top_k_val - temperature = 0.0 if greedy else self.sglang_cfg["temperature"] - # Build sampling params dict - sampling_params = { + # Build sampling params dict first, then patch in optional fields so we + # never reference ``sampling_params`` before it's bound. 
+ sampling_params: dict[str, Any] = { "temperature": temperature, "top_p": self.sglang_cfg.get("top_p", 1.0), "max_new_tokens": max_new_tokens, @@ -480,10 +478,13 @@ def _build_sampling_params( "spaces_between_special_tokens": False, } + if top_k_val != -1: + sampling_params["top_k"] = top_k_val + stop_token_ids = self.sglang_cfg.get("stop_token_ids") if stop_token_ids is not None: sampling_params["stop_token_ids"] = stop_token_ids - + if stop_strings is not None and len(stop_strings) > 0: sampling_params["stop"] = stop_strings @@ -534,28 +535,31 @@ def generate( # Build per-sample requests (each sample gets its own sampling params because # max_new_tokens is adjusted against the per-sample input length). sample_requests: list[tuple[int, dict[str, Any], list[int]]] = [] - skip_results = [] + skip_results: set[int] = set() + skip_max_length = 0 for i in range(batch_size): - input_len = input_lengths[i].item() - valid_input_ids = input_ids[i, :input_len].tolist() + input_length = input_lengths[i].item() + valid_input_ids = input_ids[i, :input_length].tolist() if context_length is not None: - max_new_tokens = min(self.sglang_cfg["max_new_tokens"], context_length - input_len) + max_new_tokens = min( + self.sglang_cfg["max_new_tokens"], context_length - input_length + ) else: max_new_tokens = self.sglang_cfg["max_new_tokens"] max_new_tokens = max(0, max_new_tokens) if max_new_tokens == 0: - skip_results.append(i) - else: - sample_sampling_params = self._build_sampling_params( - greedy=greedy, - max_new_tokens=max_new_tokens, - stop_strings=stop_strings[i] if i < len(stop_strings) else None, - input_len=input_len, - ) + skip_results.add(i) + skip_max_length = max(skip_max_length, input_length) + continue - sample_requests.append((i, sample_sampling_params, valid_input_ids)) + sample_sampling_params = self._build_sampling_params( + greedy=greedy, + max_new_tokens=max_new_tokens, + stop_strings=stop_strings[i] if i < len(stop_strings) else None, + ) + 
sample_requests.append((i, sample_sampling_params, valid_input_ids)) # Dispatch concurrently to the SGLang router with bounded concurrency. # Max concurrency = per-engine concurrency * number of engines. @@ -579,16 +583,25 @@ async def _bounded_generate_one_sample( router_ip, router_port, sp, ids, idx ) - async def _dispatch_all(): - return await asyncio.gather( + async def _dispatch_all() -> dict[int, tuple[list[int], list[float], bool]]: + gathered = await asyncio.gather( *( _bounded_generate_one_sample(idx, sp, ids) for idx, sp, ids in sample_requests ) ) - - router_results = run(_dispatch_all()) - + # generate_one_sample returns (index, tokens, logprobs, truncated). + # Re-key by the original sample index so downstream code can look up + # results directly without sorting. + return { + returned_idx: (new_tokens, new_logprobs, is_truncated) + for returned_idx, new_tokens, new_logprobs, is_truncated in gathered + } + + router_results: dict[int, tuple[list[int], list[float], bool]] = ( + run(_dispatch_all()) if sample_requests else {} + ) + # Process the outputs - preserve the original input padding structure. pad_token_id = self.sglang_cfg["_pad_token_id"] output_ids_list: list[torch.Tensor] = [] @@ -597,48 +610,42 @@ async def _dispatch_all(): unpadded_sequence_lengths_list: list[int] = [] truncated_list: list[bool] = [] - # First pass: compute max unpadded (input + generation) length across the batch - # so every sample can be right-padded to the same width. - max_length = 0 - for i, (_, new_tokens, _, _) in enumerate(router_results): - input_len = input_lengths[i].item() - max_length = max(max_length, input_len + len(new_tokens)) + # First pass: compute total_length as the max over all samples of + # (input_length + generation_length). Skipped samples contribute only + # their input_length (already tracked in ``skip_max_length``). 
+ max_length = skip_max_length + for returned_idx, (returned_tokens, _, _) in router_results.items(): + sample_input_length = input_lengths[returned_idx].item() + max_length = max(max_length, sample_input_length + len(returned_tokens)) total_length = max(max_length, padded_input_length) - if len(skip_results): - total_length = max(total_length, context_length) + # Second pass: materialize the output tensors, using a single set of + # local variable names (``generation_length`` / ``unpadded_length`` are + # always Python ints; tensor promotion happens only at the final stack). for i in range(batch_size): - input_len = input_lengths[i].item() + input_length = input_lengths[i].item() full_output = torch.full( (total_length,), pad_token_id, dtype=input_ids.dtype ) full_logprobs = torch.zeros(total_length, dtype=torch.float32) - if i in skip_results: + full_output[:input_length] = input_ids[i][:input_length] + if i in skip_results: + generation_length = 0 + is_truncated = False else: + new_tokens, new_logprobs, is_truncated = router_results[i] + generation_length = len(new_tokens) + if new_tokens: + full_output[input_length : input_length + generation_length] = ( + torch.tensor(new_tokens, dtype=input_ids.dtype) + ) + if new_logprobs: + full_logprobs[input_length : input_length + len(new_logprobs)] = ( + torch.tensor(new_logprobs, dtype=torch.float32) + ) - - for i, (_, new_tokens, new_logprobs, is_truncated) in enumerate(router_results): - input_len = input_lengths[i].item() - generation_length = len(new_tokens) - unpadded_length = input_len + generation_length - - # Build full output: [input tokens | generated tokens | pad] - full_output = torch.full( - (total_length,), pad_token_id, dtype=input_ids.dtype - ) - full_output[:input_len] = input_ids[i][:input_len] - if new_tokens: - full_output[input_len : input_len + generation_length] = torch.tensor( - new_tokens, dtype=input_ids.dtype - ) - - # Logprobs: zeros for input tokens, actual logprobs at generated positions. 
- full_logprobs = torch.zeros(total_length, dtype=torch.float32) - if new_logprobs: - for idx, logprob in enumerate(new_logprobs): - full_logprobs[input_len + idx] = logprob - + unpadded_length = input_length + generation_length output_ids_list.append(full_output) logprobs_list.append(full_logprobs) generation_lengths_list.append(generation_length) @@ -667,13 +674,6 @@ async def generate_async( greedy: bool = False, ) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]: """Generate a single sample using SGLang, yielding the result when ready. - - This mirrors ``VllmGenerationWorker.generate_async``: it is restricted to - single-sample batches (``batch_size == 1``), so the surrounding - ``asyncio.create_task(...) / asyncio.as_completed(...)`` fan-out used in - the vLLM version is unnecessary here — we simply ``await`` the single - ``generate_one_sample`` call directly and yield its output. - Args: data: BatchedDataDict with input_ids and input_lengths (batch_size must be 1) greedy: Whether to use greedy decoding instead of sampling @@ -698,29 +698,32 @@ async def generate_async( ) sample_idx = 0 - input_len = input_lengths_batch[sample_idx].item() - valid_input_ids = input_ids_batch[sample_idx, :input_len].tolist() - - # Merge stop strings for this single sample. - batch_stop_strings: list[list[str]] = data.get("stop_strings", []) - stop_strings = self._merge_stop_strings(batch_stop_strings) - per_sample_stop_strings = ( - stop_strings[sample_idx] if sample_idx < len(stop_strings) else None - ) + input_length = input_lengths_batch[sample_idx].item() + original_input_ids_single_row = input_ids_batch[sample_idx] + device = original_input_ids_single_row.device + dtype = original_input_ids_single_row.dtype + pad_token_id = self.sglang_cfg["_pad_token_id"] + # Clamp max_new_tokens against the per-sample remaining context window, + # mirroring the logic in ``generate``. 
context_length = self.sglang_cfg["sglang_cfg"].get("context_length") - remaining_ctx = ( - (context_length - input_len) if context_length is not None else None - ) - - # Short-circuit when the prompt already fills the context window. - if remaining_ctx is not None and remaining_ctx <= 0: - device = input_ids_batch.device - output_ids_single_item_batched = input_ids_batch[ - sample_idx, :input_len + if context_length is not None: + max_new_tokens = min( + self.sglang_cfg["max_new_tokens"], context_length - input_length + ) + else: + max_new_tokens = self.sglang_cfg["max_new_tokens"] + max_new_tokens = max(0, max_new_tokens) + + # Short-circuit when there is no room left in the context window. Yield + # a pure-input row (generation_length=0, truncated=False) without + # touching the SGLang router. + if max_new_tokens == 0: + output_ids_single_item_batched = original_input_ids_single_row[ + :input_length ].unsqueeze(0) logprobs_single_item = torch.zeros( - (1, input_len), dtype=torch.float32, device=device + (1, input_length), dtype=torch.float32, device=device ) empty_result = BatchedDataDict[GenerationOutputSpec]( { @@ -730,7 +733,7 @@ async def generate_async( [0], dtype=torch.long, device=device ), "unpadded_sequence_lengths": torch.tensor( - [input_len], dtype=torch.long, device=device + [input_length], dtype=torch.long, device=device ), "truncated": torch.tensor( [False], dtype=torch.bool, device=device @@ -740,14 +743,22 @@ async def generate_async( yield (sample_idx, empty_result) return + # Merge stop strings for this single sample. 
+ batch_stop_strings: list[list[str]] = data.get("stop_strings", []) + stop_strings = self._merge_stop_strings(batch_stop_strings) + per_sample_stop_strings = ( + stop_strings[sample_idx] if sample_idx < len(stop_strings) else None + ) + sampling_params = self._build_sampling_params( greedy=greedy, + max_new_tokens=max_new_tokens, stop_strings=per_sample_stop_strings, - input_len=input_len, ) router_ip = self.sglang_cfg["sglang_router"]["sglang_router_ip"] router_port = self.sglang_cfg["sglang_router"]["sglang_router_port"] + valid_input_ids = original_input_ids_single_row[:input_length].tolist() # batch_size == 1, so no task fan-out / as_completed is needed. Just # await the single coroutine directly. @@ -760,49 +771,42 @@ async def generate_async( ) # Build the single-sample output tensor: [input | generated]. - pad_token_id = self.sglang_cfg["_pad_token_id"] - original_input_ids_single_row = input_ids_batch[sample_idx] - device = original_input_ids_single_row.device - dtype = original_input_ids_single_row.dtype - - num_generated_tokens = len(new_tokens) - final_output_tensor_len = input_len + num_generated_tokens + generation_length = len(new_tokens) + unpadded_length = input_length + generation_length output_ids_single_item = torch.full( - (final_output_tensor_len,), pad_token_id, dtype=dtype, device=device + (unpadded_length,), pad_token_id, dtype=dtype, device=device ) - output_ids_single_item[:input_len] = original_input_ids_single_row[:input_len] + output_ids_single_item[:input_length] = original_input_ids_single_row[ + :input_length + ] if new_tokens: - output_ids_single_item[input_len:final_output_tensor_len] = torch.tensor( + output_ids_single_item[input_length:unpadded_length] = torch.tensor( new_tokens, dtype=dtype, device=device ) - output_ids_single_item_batched = output_ids_single_item.unsqueeze(0) # Logprobs: zeros for input tokens, raw floats at generated positions. 
logprobs_single_item = torch.zeros( - (1, final_output_tensor_len), dtype=torch.float32, device=device + (1, unpadded_length), dtype=torch.float32, device=device ) if new_logprobs: - for idx, logprob in enumerate(new_logprobs): - logprobs_single_item[0, input_len + idx] = logprob - - generation_lengths_tensor = torch.tensor( - [num_generated_tokens], dtype=torch.long, device=device - ) - unpadded_sequence_lengths_tensor = torch.tensor( - [final_output_tensor_len], dtype=torch.long, device=device - ) - truncated_tensor = torch.tensor( - [bool(is_truncated)], dtype=torch.bool, device=device - ) + logprobs_single_item[ + 0, input_length : input_length + len(new_logprobs) + ] = torch.tensor(new_logprobs, dtype=torch.float32, device=device) result_batch = BatchedDataDict[GenerationOutputSpec]( { - "output_ids": output_ids_single_item_batched, + "output_ids": output_ids_single_item.unsqueeze(0), "logprobs": logprobs_single_item, - "generation_lengths": generation_lengths_tensor, - "unpadded_sequence_lengths": unpadded_sequence_lengths_tensor, - "truncated": truncated_tensor, + "generation_lengths": torch.tensor( + [generation_length], dtype=torch.long, device=device + ), + "unpadded_sequence_lengths": torch.tensor( + [unpadded_length], dtype=torch.long, device=device + ), + "truncated": torch.tensor( + [bool(is_truncated)], dtype=torch.bool, device=device + ), } ) From 348966361fe11a8a75b33c3fae0ca191833650ba Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 14:01:38 -0700 Subject: [PATCH 11/36] update --- .../models/generation/redesign/async_utils.py | 33 +++ nemo_rl/models/generation/redesign/config.py | 1 + nemo_rl/models/generation/redesign/pg.py | 183 -------------- .../redesign/{utils.py => ray_http_utils.py} | 133 +++++++++- .../generation/redesign/sglang_generation.py | 235 +++++++----------- .../generation/redesign/sglang_worker.py | 16 +- 6 files changed, 258 insertions(+), 343 deletions(-) create mode 100644 
nemo_rl/models/generation/redesign/async_utils.py delete mode 100644 nemo_rl/models/generation/redesign/pg.py rename nemo_rl/models/generation/redesign/{utils.py => ray_http_utils.py} (72%) diff --git a/nemo_rl/models/generation/redesign/async_utils.py b/nemo_rl/models/generation/redesign/async_utils.py new file mode 100644 index 0000000000..a1d35b9614 --- /dev/null +++ b/nemo_rl/models/generation/redesign/async_utils.py @@ -0,0 +1,33 @@ +import asyncio +import threading + + +class AsyncLoopThread: + def __init__(self): + self.loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._start_loop, daemon=True) + self._thread.start() + + def _start_loop(self): + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + def run(self, coro): + # Schedule a coroutine onto the loop and block until it's done + return asyncio.run_coroutine_threadsafe(coro, self.loop).result() + + +# Create one global instance +async_loop = None + + +def get_async_loop(): + global async_loop + if async_loop is None: + async_loop = AsyncLoopThread() + return async_loop + + +def run(coro): + """Run a coroutine in the background event loop.""" + return get_async_loop().run(coro) diff --git a/nemo_rl/models/generation/redesign/config.py b/nemo_rl/models/generation/redesign/config.py index 00f8b68fcb..7f74f47529 100644 --- a/nemo_rl/models/generation/redesign/config.py +++ b/nemo_rl/models/generation/redesign/config.py @@ -104,6 +104,7 @@ class SGLangRouter(TypedDict): sglang_router_ip: NotRequired[str] sglang_router_port: NotRequired[int] router_policy: NotRequired[str] + use_distributed_post: NotRequired[bool] class SGLangConfig(GenerationConfig): """Configuration for SGLang runtime.""" diff --git a/nemo_rl/models/generation/redesign/pg.py b/nemo_rl/models/generation/redesign/pg.py deleted file mode 100644 index 7cba97e1bd..0000000000 --- a/nemo_rl/models/generation/redesign/pg.py +++ /dev/null @@ -1,183 +0,0 @@ -import logging -import socket - -import ray -from 
ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -from .actor_group import RayTrainGroup -from .rollout import RolloutManager - -logger = logging.getLogger(__name__) - - -def sort_key(x): - index, node_identifier, gpu_id = x - # Sort by node IP number and then by GPU ID - try: - # try to parse it as an IP address. - ip_address = node_identifier - node_ip_parts = list(map(int, ip_address.split("."))) - except ValueError: - # Try to resolve the hostname to an IP address. - try: - ip_address = socket.gethostbyname(node_identifier) - node_ip_parts = list(map(int, ip_address.split("."))) - except (socket.gaierror, TypeError): - # Instead, we convert each character of the original identifier string - # to its ASCII value. This provides a stable and consistent numerical - # representation that allows for sorting. - node_ip_parts = [ord(c) for c in node_identifier] - - return (node_ip_parts, gpu_id) - - -def _create_placement_group(num_gpus): - """Create a placement group with the specified number of GPUs.""" - bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] - pg = placement_group(bundles, strategy="PACK") - num_bundles = len(bundles) - - ray.get(pg.ready()) - # use info actor to get the GPU id - info_actors = [] - for i in range(num_bundles): - info_actors.append( - InfoActor.options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=i, - ) - ).remote() - ) - gpu_ids = ray.get([actor.get_ip_and_gpu_id.remote() for actor in info_actors]) - for actor in info_actors: - ray.kill(actor) - - bundle_infos = [(i, gpu_ids[i][0], gpu_ids[i][1]) for i in range(num_bundles)] - sorted_bundle_infos = sorted(bundle_infos, key=sort_key) - pg_reordered_bundle_indices = [info[0] for info in sorted_bundle_infos] - # Map from logical index -> physical GPU ID - pg_reordered_gpu_ids = [gpu_ids[info[0]][1] for info in sorted_bundle_infos] - - for i 
in range(num_bundles): - actual_bundle_index = pg_reordered_bundle_indices[i] - logger.info( - f" bundle {i:4}, actual_bundle_index: {actual_bundle_index:4}, " - f"node: {gpu_ids[actual_bundle_index][0]}, gpu: {gpu_ids[actual_bundle_index][1]}" - ) - - return pg, pg_reordered_bundle_indices, pg_reordered_gpu_ids - - -def create_placement_groups(args): - """Create placement groups for actor and rollout engines.""" - - num_gpus = 0 - if args.debug_train_only: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = 0 - if args.use_critic: - num_gpus += args.critic_num_nodes * args.critic_num_gpus_per_node - critic_offset = args.actor_num_nodes * args.actor_num_gpus_per_node - elif args.debug_rollout_only: - num_gpus = args.rollout_num_gpus - rollout_offset = 0 - elif args.colocate: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = 0 - if args.use_critic: - num_gpus += args.critic_num_nodes * args.critic_num_gpus_per_node - critic_offset = args.actor_num_nodes * args.actor_num_gpus_per_node - else: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus - rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node - if args.use_critic: - num_gpus += args.critic_num_nodes * args.critic_num_gpus_per_node - critic_offset = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset += args.critic_num_nodes * args.critic_num_gpus_per_node - - logger.info(f"Creating placement group with {num_gpus} GPUs...") - pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids = _create_placement_group(num_gpus) - - rollout_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[rollout_offset:] - rollout_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[rollout_offset:] - if args.use_critic: - critic_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[critic_offset:] - critic_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[critic_offset:] - - 
return { - "actor": (pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids), - "critic": (pg, critic_pg_reordered_bundle_indices, critic_pg_reordered_gpu_ids) if args.use_critic else None, - "rollout": (pg, rollout_pg_reordered_bundle_indices, rollout_pg_reordered_gpu_ids), - } - - -def allocate_train_group(args, num_nodes, num_gpus_per_node, pg): - return RayTrainGroup( - args=args, - num_nodes=num_nodes, - num_gpus_per_node=num_gpus_per_node, - pg=pg, - num_gpus_per_actor=0.4, - ) - - -def create_training_models(args, pgs, rollout_manager): - actor_model = allocate_train_group( - args=args, - num_nodes=args.actor_num_nodes, - num_gpus_per_node=args.actor_num_gpus_per_node, - pg=pgs["actor"], - ) - if args.use_critic: - critic_model = allocate_train_group( - args=args, - num_nodes=args.critic_num_nodes, - num_gpus_per_node=args.critic_num_gpus_per_node, - pg=pgs["critic"], - ) - critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False) - else: - critic_model = None - - start_rollout_ids = ray.get( - actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss) - ) - - assert len(set(start_rollout_ids)) == 1 - if args.start_rollout_id is None: - args.start_rollout_id = start_rollout_ids[0] - - if args.use_critic: - ray.get(critic_init_handle) - actor_model.connect(critic_model) - - actor_model.set_rollout_manager(rollout_manager) - if args.rollout_global_dataset: - ray.get(rollout_manager.load.remote(args.start_rollout_id - 1)) - - return actor_model, critic_model - - -def create_rollout_manager(args, pg): - rollout_manager = RolloutManager.options( - num_cpus=1, - num_gpus=0, - ).remote(args, pg) - - # calculate num_rollout from num_epoch - num_rollout_per_epoch = None - if args.num_rollout is None: - num_rollout_per_epoch = ray.get(rollout_manager.get_num_rollout_per_epoch.remote()) - args.num_rollout = num_rollout_per_epoch * args.num_epoch - assert args.num_rollout > 0 - - if 
args.check_weight_update_equal: - ray.get(rollout_manager.check_weights.remote(action="snapshot")) - ray.get(rollout_manager.check_weights.remote(action="reset_tensors")) - - if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) - - return rollout_manager, num_rollout_per_epoch diff --git a/nemo_rl/models/generation/redesign/utils.py b/nemo_rl/models/generation/redesign/ray_http_utils.py similarity index 72% rename from nemo_rl/models/generation/redesign/utils.py rename to nemo_rl/models/generation/redesign/ray_http_utils.py index 2a781f5761..f9b7afac32 100644 --- a/nemo_rl/models/generation/redesign/utils.py +++ b/nemo_rl/models/generation/redesign/ray_http_utils.py @@ -1,3 +1,119 @@ +import asyncio +import io +import ipaddress +import json +import logging +import multiprocessing +import os +import random +import socket +from multiprocessing.reduction import ForkingPickler +from typing import Callable, Union + +import httpx +import pybase64 +import ray +import torch +from torch.multiprocessing import reductions + +from nemo_rl.models.generation.redesign.config import SGLangConfig + +logger = logging.getLogger(__name__) + + +class RayActor: + """Base class for Ray actors providing node IP / free port helpers.""" + + @staticmethod + def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): + return get_current_node_ip(), get_free_port( + start_port=start_port, consecutive=consecutive + ) + + def get_master_addr_and_port(self): + return self.master_addr, self.master_port + + +class MultiprocessingSerializer: # pragma: no cover + """Serialize/deserialize Python objects using ForkingPickler for IPC. + + This class enables serialization of objects (including CUDA tensors with IPC + handles) for transfer between processes via HTTP or other mechanisms. 
+ + Original source (sglang v0.5.2): + https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 + """ + + @staticmethod + def serialize(obj, output_str: bool = False): + """Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = pybase64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. + """ + if isinstance(data, str): + # Decode base64 string to bytes + data = pybase64.b64decode(data, validate=True) + + return ForkingPickler.loads(data) + + + +NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", + "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", + "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", + "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", +] + + +@ray.remote +class Lock(RayActor): + def __init__(self): + self._locked = False # False: unlocked, True: locked + + def acquire(self): + """ + Try to acquire the lock. Returns True if acquired, False otherwise. + Caller should retry until it returns True. + """ + if not self._locked: + self._locked = True + return True + return False + + def release(self): + """Release the lock, allowing others to acquire.""" + assert self._locked, "Lock is not acquired, cannot release." 
+ self._locked = False @@ -27,9 +143,6 @@ def is_port_available(port): def get_host_info(): hostname = socket.gethostname() - if env_overwrite_local_ip := os.getenv(MILES_HOST_IP_ENV, None): - return hostname, env_overwrite_local_ip - def _is_loopback(ip): return ip.startswith("127.") or ip == "::1" @@ -194,13 +307,17 @@ async def _post(client, url, payload, max_retries=60, action="post"): return output -def init_http_client(args: SglangSpecificArgs): +def init_http_client(args: SGLangConfig): """Initialize HTTP client and optionally enable distributed POST via Ray.""" global _http_client, _client_concurrency, _distributed_post_enabled - if not args.rollout_num_gpus: + if not args.get("sglang_server").get("num_gpus"): return - _client_concurrency = args["sglang_server"]["sglang_server_concurrency"] * args["sglang_server"]["num_gpus"] // args["sglang_server"]["num_gpus_per_engine"] + _client_concurrency = ( + args["sglang_server"]["sglang_server_concurrency"] + * args["sglang_server"]["num_gpus"] + // args["sglang_server"]["num_gpus_per_engine"] + ) if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency), @@ -208,12 +325,12 @@ def init_http_client(args: SglangSpecificArgs): ) # Optionally initialize distributed POST via Ray without changing interfaces - if args.use_distributed_post: + if args.get("sglang_router").get("use_distributed_post"): _init_ray_distributed_post(args) _distributed_post_enabled = True -def _init_ray_distributed_post(args: SglangSpecificArgs): +def _init_ray_distributed_post(args: SGLangConfig): """Initialize one or more Ray async actors per node for HTTP POST. Uses NodeAffinitySchedulingStrategy to place actors on distinct nodes. 
diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index a151a40b49..3db55a2a17 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -14,118 +14,43 @@ import ray import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS - -from miles.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig -from miles.backends.sglang_utils.sglang_engine import SGLangEngine -from miles.rollout.base_types import ( - RolloutFnConstructorInput, - RolloutFnEvalInput, - RolloutFnTrainInput, - call_rollout_fn, +from sglang.srt.constants import ( + GPU_MEMORY_TYPE_CUDA_GRAPH, + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, ) -from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function -from miles.utils import tracking_utils -from miles.utils.environ import enable_experimental_rollout_refactor -from miles.utils.health_monitor import RolloutHealthMonitor -from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client -from miles.utils.iter_utils import group_by -from miles.utils.logging_utils import configure_logger -from miles.utils.metric_checker import MetricChecker -from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix -from miles.utils.misc import load_function -from miles.utils.ray_utils import Box -from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions -from miles.utils.tracking_utils import init_tracking -from miles.utils.types import Sample - -from ..utils.metric_utils import has_repetition -from .utils import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, Lock + +from nemo_rl.distributed.batched_data_dict 
import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ( + ClusterConfig, + RayVirtualCluster, + get_reordered_bundle_and_gpu_ids, +) +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationInterface, + GenerationOutputSpec, + verify_right_padding, +) +from nemo_rl.models.generation.redesign.async_utils import run +from nemo_rl.models.generation.redesign.config import SGLangConfig +from nemo_rl.models.generation.redesign.ray_http_utils import ( + NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, + Lock, + _wrap_ipv6, + find_available_port, + get_host_info, + init_http_client, + post, + run_router, +) +from nemo_rl.models.generation.redesign.sglang_worker import SGLangEngine logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# ServerGroup / RolloutServer abstractions -# --------------------------------------------------------------------------- - -# use_unified_pg = True for Nemo -# if use_unified_pg: -# # Create a single unified placement group for cross-node model parallelism -# all_bundles = [] -# for bundle_count in self._bundle_ct_per_node_list: -# for _ in range(bundle_count): -# all_bundles.append( -# {"CPU": num_cpus_per_bundle, "GPU": num_gpus_per_bundle} -# ) - -# placement_groups = [ -# placement_group( -# bundles=all_bundles, strategy=strategy, name=f"{self.name}-unified" -# ) -# ] - -# pg = cluster._init_placement_groups(strategy="PACK", use_unified_pg=True)[0] -# pg_reordered_bundle_indices = cluster._get_sorted_bundle_indices() - - -class AsyncLoopThread: - def __init__(self): - self.loop = asyncio.new_event_loop() - self._thread = threading.Thread(target=self._start_loop, daemon=True) - self._thread.start() - - def _start_loop(self): - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def run(self, coro): - # Schedule a 
coroutine onto the loop and block until it's done - return asyncio.run_coroutine_threadsafe(coro, self.loop).result() - -# Create one global instance -async_loop = None - -def get_async_loop(): - global async_loop - if async_loop is None: - async_loop = AsyncLoopThread() - return async_loop - -def run(coro): - """Run a coroutine in the background event loop.""" - return get_async_loop().run(coro) - - -async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_params, input_ids, index: int): - """Generate using traditional SGLang router with token-based workflow""" - url = f"http://{sglang_router_ip}:{sglang_router_port}/generate" - - # Prepare payload for sglang server - payload = { - "sampling_params": sampling_params, - "return_logprob": True, - "input_ids": input_ids, - } - - output = await post(url, payload) - - if "output_token_logprobs" in output["meta_info"]: - response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] - response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]] - else: - response_tokens, response_log_probs = [], [] - - response_truncated = False - if response_truncated["meta_info"]["output_token_logprobs"] is not None and output["meta_info"]["output_token_logprobs"] == "length": - response_truncated = True - - return index, response_tokens, response_log_probs, response_truncated - - @dataclasses.dataclass class ServerGroup: """A group of homogeneous SGLang engines with the same configuration. 
@@ -146,6 +71,8 @@ class ServerGroup: model_path: str | None = None router_ip: str | None = None router_port: int | None = None + cluster_cfg: Any = None + sglang_cfg: Any = None @property def nodes_per_engine(self): @@ -155,10 +82,6 @@ def nodes_per_engine(self): def engines(self): """Node-0 engines only (for multi-node serving).""" return self.all_engines[:: self.nodes_per_engine] - - @num_new_engines.setter - def num_new_engines(self, value): - self.num_new_engines = value @property def engine_gpu_counts(self) -> list[int]: @@ -224,7 +147,8 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis "env_vars": env_vars, }, ).remote( - self.args, + self.cluster_cfg, + self.sglang_cfg, rank=global_rank, base_gpu_id=base_gpu_id, num_gpus_per_engine=self.num_gpus_per_engine, @@ -240,8 +164,9 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis base_port = max(port_cursors.values()) if port_cursors else 15000 addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( + cluster_cfg=self.cluster_cfg, + sglang_cfg=self.sglang_cfg, rollout_engines=rollout_engines, - num_gpus_per_engine=self.num_gpus_per_engine, rank_offset=self.rank_offset, base_port=base_port, ) @@ -258,34 +183,25 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis def recover(self): """Recover dead engines across all active groups, overlapping init.""" - dead_per_group = [[i for i, engine in enumerate(g.all_engines) if engine is None] for g in self.server_groups] + dead_indices = [i for i, engine in enumerate(self.all_engines) if engine is None] - all_handles = [] port_cursors: dict[int, int] = {} - for g in self.server_groups: - handles, port_cursors = g.start_engines(port_cursors) - all_handles.extend(handles) - if all_handles: - ray.get(all_handles) + handles = self.start_engines(port_cursors) + if handles: + ray.get(handles) release_handles = [] updatable_new_engines = [] - 
non_updatable_groups_engines: list[tuple[str, list]] = [] - for g, dead_indices in zip(self.server_groups, dead_per_group, strict=True): - assert g.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" - if g.needs_offload and dead_indices: - new_engines = [g.all_engines[i] for i in dead_indices] - release_handles.extend(engine.release_memory_occupation.remote() for engine in new_engines) - if self.update_weights: - updatable_new_engines.extend(new_engines) - elif g.model_path: - non_updatable_groups_engines.append((g.model_path, new_engines)) + + assert self.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" + if self.needs_offload and dead_indices: + new_engines = [self.all_engines[i] for i in dead_indices] + release_handles.extend(engine.release_memory_occupation.remote() for engine in new_engines) + updatable_new_engines.extend(new_engines) if release_handles: ray.get(release_handles) all_resume_engines = updatable_new_engines[:] - for _model_path, engines in non_updatable_groups_engines: - all_resume_engines.extend(engines) if all_resume_engines: ray.get( [ @@ -321,17 +237,12 @@ def onload_weights_from_disk(self): return [ engine.update_weights_from_disk.remote(self.model_path) for engine in self.engines if engine is not None ] - -# --------------------------------------------------------------------------- -# SGLangGeneration -# --------------------------------------------------------------------------- + class SGLangGeneration(GenerationInterface): """The class to run rollout and convert rollout data to training data.""" def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglang_cfg: SGLangConfig): - configure_logger() - self.cluster = cluster self.cluster_cfg = cluster_cfg self.sglang_cfg = sglang_cfg @@ -341,16 +252,15 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan ) self.pg_reordered_bundle_indices, 
self.pg_reordered_gpu_ids = get_reordered_bundle_and_gpu_ids(self.pg) - init_http_client(args) - self.server_group = start_rollout_servers(args, (self.pg, self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids)) + init_http_client(sglang_cfg) + self.server_group = start_rollout_servers( + sglang_cfg, + cluster_cfg, + (self.pg, self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids), + ) self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() - - def dispose(self): - if self._metric_checker is not None: - self._metric_checker.dispose() - @property def rollout_engines(self): """All node-0 engines across all servers / models.""" @@ -394,7 +304,8 @@ def recover_updatable_engines(self): updates from training). """ server_group = self.server_group - if self.rollout_id == -1 or server_group is None: + + if server_group is None: engines = server_group.engines if server_group else [] gpu_counts = server_group.engine_gpu_counts if server_group else [] gpu_offsets = server_group.engine_gpu_offsets if server_group else [] @@ -812,10 +723,40 @@ async def generate_async( yield (sample_idx, result_batch) + # --------------------------------------------------------------------------- -# Port allocation helpers +# Generate one sample helper # --------------------------------------------------------------------------- +async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_params, input_ids, index: int): + """Generate using traditional SGLang router with token-based workflow""" + url = f"http://{sglang_router_ip}:{sglang_router_port}/generate" + + # Prepare payload for sglang server + payload = { + "sampling_params": sampling_params, + "return_logprob": True, + "input_ids": input_ids, + } + + output = await post(url, payload) + + if "output_token_logprobs" in output["meta_info"]: + response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]] + response_log_probs = [item[0] for item in 
output["meta_info"]["output_token_logprobs"]] + else: + response_tokens, response_log_probs = [], [] + response_truncated = False + if output["meta_info"]["output_token_logprobs"] is not None and output["meta_info"]["output_token_logprobs"] == "length": + response_truncated = True + + return index, response_tokens, response_log_probs, response_truncated + + + +# --------------------------------------------------------------------------- +# Port allocation helpers +# --------------------------------------------------------------------------- def _allocate_rollout_engine_addr_and_ports_normal( *, cluster_cfg, @@ -980,9 +921,11 @@ def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup: rank_offset=engine_offset, gpu_offset=gpu_offset, needs_offload=needs_offload, - model_path= model_path, + model_path=model_path, router_ip=router_ip, router_port=router_port, + cluster_cfg=cluster_cfg, + sglang_cfg=sglang_cfg, ) handles, port_cursors = server_group.start_engines(port_cursors) diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 269e9539ef..76506ce875 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -6,14 +6,19 @@ import time from urllib.parse import quote +import ray import requests import sglang_router from packaging.version import parse from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError -from miles.utils.env_report import collect_and_print_node_env_report -import ray + +from nemo_rl.models.generation.redesign.ray_http_utils import ( + get_current_node_ip, + get_free_port, + get_host_info, +) logger = logging.getLogger(__name__) @@ -492,7 +497,7 @@ def _compute_server_args( "pp_size": sglang_cfg["sglang_cfg"]["pp_size"], "ep_size": sglang_cfg["sglang_cfg"]["ep_size"], # always skip warmup to prevent warmup timeout. 
- "skip_server_warmup": True, + "skip_server_warmup": sglang_cfg["sglang_cfg"]["skip_server_warmup"], # always enable draft weights cpu backup so that we run training without mtp weights. "enable_draft_weights_cpu_backup": True, } @@ -500,7 +505,7 @@ def _compute_server_args( for key in [ "dtype", "kv_cache_dtype", - " ", + "context_length", "max_running_requests", "chunked_prefill_size", "max_prefill_tokens", @@ -515,5 +520,4 @@ def _compute_server_args( if key in sglang_cfg["sglang_cfg"]: kwargs[key] = sglang_cfg["sglang_cfg"][key] - return kwargs - + return kwargs \ No newline at end of file From cd4d62507b8070111c2dffc70996b39360e81fdc Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 14:13:26 -0700 Subject: [PATCH 12/36] update --- .../models/generation/redesign/http_utils.py | 162 +++++++ nemo_rl/models/generation/redesign/misc.py | 98 +++++ .../generation/redesign/ray_http_utils.py | 410 ------------------ .../models/generation/redesign/ray_utils.py | 147 +++++++ .../generation/redesign/sglang_generation.py | 14 +- .../generation/redesign/sglang_worker.py | 2 +- 6 files changed, 417 insertions(+), 416 deletions(-) create mode 100644 nemo_rl/models/generation/redesign/http_utils.py create mode 100644 nemo_rl/models/generation/redesign/misc.py delete mode 100644 nemo_rl/models/generation/redesign/ray_http_utils.py create mode 100644 nemo_rl/models/generation/redesign/ray_utils.py diff --git a/nemo_rl/models/generation/redesign/http_utils.py b/nemo_rl/models/generation/redesign/http_utils.py new file mode 100644 index 0000000000..ccc83a4dfe --- /dev/null +++ b/nemo_rl/models/generation/redesign/http_utils.py @@ -0,0 +1,162 @@ +import asyncio +import json +import logging + +import httpx + +from nemo_rl.models.generation.redesign.config import SGLangConfig + +logger = logging.getLogger(__name__) + + +_http_client: httpx.AsyncClient | None = None +_client_concurrency: int = 0 + +# Optional Ray-based distributed POST dispatch 
+_distributed_post_enabled: bool = False +_post_actors: list[object] = [] +_post_actor_idx: int = 0 + + +def _next_actor(): + global _post_actor_idx + if not _post_actors: + return None + actor = _post_actors[_post_actor_idx % len(_post_actors)] + _post_actor_idx = (_post_actor_idx + 1) % len(_post_actors) + return actor + + +async def _post(client, url, payload, max_retries=60, action="post"): + retry_count = 0 + while retry_count < max_retries: + try: + if action in ("delete", "get"): + assert not payload + response = await getattr(client, action)(url) + else: + response = await getattr(client, action)(url, json=payload or {}) + response.raise_for_status() + try: + output = response.json() + except json.JSONDecodeError: + output = response.text + except Exception as e: + retry_count += 1 + + if isinstance(e, httpx.HTTPStatusError): + response_text = e.response.text + else: + response_text = None + + logger.info( + f"Error: {e}, retrying... (attempt {retry_count}/{max_retries}, url={url}, response={response_text})" + ) + if retry_count >= max_retries: + logger.info(f"Max retries ({max_retries}) reached, failing... 
(url={url})") + raise e + await asyncio.sleep(1) + continue + break + + return output + + +def init_http_client(args: SGLangConfig): + """Initialize HTTP client and optionally enable distributed POST via Ray.""" + global _http_client, _client_concurrency, _distributed_post_enabled + if not args.get("sglang_server").get("num_gpus"): + return + + _client_concurrency = ( + args["sglang_server"]["sglang_server_concurrency"] + * args["sglang_server"]["num_gpus"] + // args["sglang_server"]["num_gpus_per_engine"] + ) + if _http_client is None: + _http_client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=_client_concurrency), + timeout=httpx.Timeout(None), + ) + + # Optionally initialize distributed POST via Ray without changing interfaces + if args.get("sglang_router").get("use_distributed_post"): + _init_ray_distributed_post(args) + _distributed_post_enabled = True + + +def _init_ray_distributed_post(args: SGLangConfig): + """Initialize one or more Ray async actors per node for HTTP POST. 
+ """ + global _post_actors + if _post_actors: + return # Already initialized + + import ray + from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + + # Discover alive nodes + nodes = [n for n in ray.nodes() if n.get("Alive")] + if not nodes: + raise RuntimeError("No alive Ray nodes to place HTTP POST actors.") + + # Define the async actor + @ray.remote + class _HttpPosterActor: + def __init__(self, concurrency: int): + # Lazy creation to this actor's event loop + self._client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=max(1, concurrency)), + timeout=httpx.Timeout(None), + ) + + async def do_post(self, url, payload, max_retries=60, action="post"): + return await _post(self._client, url, payload, max_retries, action=action) + + # Create actors per node + created = [] + # Distribute client concurrency across actors (at least 1 per actor) + per_actor_conc = (_client_concurrency + len(nodes)) // len(nodes) + + for node in nodes: + node_id = node["NodeID"] + scheduling = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) + for _ in range(args["sglang_server"]["num_gpus_per_engine"]): + actor = _HttpPosterActor.options( + name=None, + lifetime="detached", + scheduling_strategy=scheduling, + max_concurrency=per_actor_conc, + # Use tiny CPU to schedule + num_cpus=0.001, + ).remote(per_actor_conc) + created.append(actor) + + _post_actors = created + + +# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) +async def post(url, payload, max_retries=60, action="post"): + # If distributed mode is enabled and actors exist, dispatch via Ray. 
+ if _distributed_post_enabled and _post_actors: + try: + import ray + + actor = _next_actor() + if actor is not None: + # Use a thread to avoid blocking the event loop on ray.get + obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) + return await asyncio.to_thread(ray.get, obj_ref) + except Exception as e: + logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") + # fall through to local + + return await _post(_http_client, url, payload, max_retries, action=action) + + +# TODO unify w/ `post` to add retries and remote-execution +async def get(url): + response = await _http_client.get(url) + response.raise_for_status() + output = response.json() + return output diff --git a/nemo_rl/models/generation/redesign/misc.py b/nemo_rl/models/generation/redesign/misc.py new file mode 100644 index 0000000000..2825761d8e --- /dev/null +++ b/nemo_rl/models/generation/redesign/misc.py @@ -0,0 +1,98 @@ +import io +import logging +import multiprocessing +from multiprocessing.reduction import ForkingPickler + +import pybase64 + +logger = logging.getLogger(__name__) + + +NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", + "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", + "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", + "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", +] + + +class MultiprocessingSerializer: # pragma: no cover + """Serialize/deserialize Python objects using ForkingPickler for IPC. + + This class enables serialization of objects (including CUDA tensors with IPC + handles) for transfer between processes via HTTP or other mechanisms. 
+ + Original source (sglang v0.5.2): + https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 + """ + + @staticmethod + def serialize(obj, output_str: bool = False): + """Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = pybase64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. + + Returns: + The deserialized Python object. + """ + if isinstance(data, str): + # Decode base64 string to bytes + data = pybase64.b64decode(data, validate=True) + + return ForkingPickler.loads(data) + + +def run_router(args): + try: + from sglang_router.launch_router import launch_router + + router = launch_router(args) + if router is None: + return 1 + return 0 + except Exception as e: + logger.info(e) + return 1 + + +def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None: + """Terminate a process gracefully, with forced kill as fallback. 
+ + Args: + process: The process to terminate + timeout: Seconds to wait for graceful termination before forcing kill + """ + if not process.is_alive(): + return + + process.terminate() + process.join(timeout=timeout) + if process.is_alive(): + process.kill() + process.join() diff --git a/nemo_rl/models/generation/redesign/ray_http_utils.py b/nemo_rl/models/generation/redesign/ray_http_utils.py deleted file mode 100644 index f9b7afac32..0000000000 --- a/nemo_rl/models/generation/redesign/ray_http_utils.py +++ /dev/null @@ -1,410 +0,0 @@ -import asyncio -import io -import ipaddress -import json -import logging -import multiprocessing -import os -import random -import socket -from multiprocessing.reduction import ForkingPickler -from typing import Callable, Union - -import httpx -import pybase64 -import ray -import torch -from torch.multiprocessing import reductions - -from nemo_rl.models.generation.redesign.config import SGLangConfig - -logger = logging.getLogger(__name__) - - -class RayActor: - """Base class for Ray actors providing node IP / free port helpers.""" - - @staticmethod - def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): - return get_current_node_ip(), get_free_port( - start_port=start_port, consecutive=consecutive - ) - - def get_master_addr_and_port(self): - return self.master_addr, self.master_port - - -class MultiprocessingSerializer: # pragma: no cover - """Serialize/deserialize Python objects using ForkingPickler for IPC. - - This class enables serialization of objects (including CUDA tensors with IPC - handles) for transfer between processes via HTTP or other mechanisms. - - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 - """ - - @staticmethod - def serialize(obj, output_str: bool = False): - """Serialize a Python object using ForkingPickler. - - Args: - obj: The object to serialize. 
- output_str (bool): If True, return a base64-encoded string instead of raw bytes. - - Returns: - bytes or str: The serialized object. - """ - buf = io.BytesIO() - ForkingPickler(buf).dump(obj) - buf.seek(0) - output = buf.read() - - if output_str: - # Convert bytes to base64-encoded string - output = pybase64.b64encode(output).decode("utf-8") - - return output - - @staticmethod - def deserialize(data): - """Deserialize a previously serialized object. - - Args: - data (bytes or str): The serialized data, optionally base64-encoded. - - Returns: - The deserialized Python object. - """ - if isinstance(data, str): - # Decode base64 string to bytes - data = pybase64.b64decode(data, validate=True) - - return ForkingPickler.loads(data) - - - -NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", - "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", - "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", - "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", -] - - -@ray.remote -class Lock(RayActor): - def __init__(self): - self._locked = False # False: unlocked, True: locked - - def acquire(self): - """ - Try to acquire the lock. Returns True if acquired, False otherwise. - Caller should retry until it returns True. - """ - if not self._locked: - self._locked = True - return True - return False - - def release(self): - """Release the lock, allowing others to acquire.""" - assert self._locked, "Lock is not acquired, cannot release." 
- self._locked = False - - - -def find_available_port(base_port: int): - port = base_port + random.randint(100, 1000) - while True: - if is_port_available(port): - return port - if port < 60000: - port += 42 - else: - port -= 43 - -def is_port_available(port): - """Return whether a port is available.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("", port)) - s.listen(1) - return True - except OSError: - return False - except OverflowError: - return False - -def get_host_info(): - hostname = socket.gethostname() - - def _is_loopback(ip): - return ip.startswith("127.") or ip == "::1" - - def _resolve_ip(family, test_target_ip): - """ - Attempt to get the local LAN IP for the specific family (IPv4/IPv6). - Strategy: UDP Probe (Preferred) -> Hostname Resolution (Fallback) -> None - """ - - # Strategy 1: UDP Connect Probe (Most accurate, relies on routing table) - # Useful when the machine has a default gateway or internet access. - try: - with socket.socket(family, socket.SOCK_DGRAM) as s: - # The IP doesn't need to be reachable, but the routing table must exist. - s.connect((test_target_ip, 80)) - ip = s.getsockname()[0] - if not _is_loopback(ip): - return ip - except Exception: - pass # Route unreachable or network error, move to next strategy. - - # Strategy 2: Hostname Resolution (Fallback for offline clusters) - # Useful for offline environments where UDP connect fails but /etc/hosts is configured. - try: - # getaddrinfo allows specifying the family (AF_INET or AF_INET6) - # Result format: [(family, type, proto, canonname, sockaddr), ...] 
- infos = socket.getaddrinfo(hostname, None, family=family, type=socket.SOCK_STREAM) - - for info in infos: - ip = info[4][0] # The first element of sockaddr is the IP - # Must filter out loopback addresses to avoid "127.0.0.1" issues - if not _is_loopback(ip): - return ip - except Exception: - pass - - return None - - prefer_ipv6 = os.getenv("MILES_PREFER_IPV6", "0").lower() in ("1", "true", "yes", "on") - local_ip = None - final_fallback = "127.0.0.1" - - if prefer_ipv6: - # [Strict Mode] IPv6 Only - # 1. Try UDP V6 Probe - # 2. Try Hostname Resolution (V6) - # If failed, fallback to V6 loopback. Never mix with V4. - local_ip = _resolve_ip(socket.AF_INET6, "2001:4860:4860::8888") - final_fallback = "::1" - else: - # [Strict Mode] IPv4 Only (Default) - # 1. Try UDP V4 Probe - # 2. Try Hostname Resolution (V4) - # If failed, fallback to V4 loopback. Never mix with V6. - local_ip = _resolve_ip(socket.AF_INET, "8.8.8.8") - final_fallback = "127.0.0.1" - - return hostname, local_ip or final_fallback - -def get_current_node_ip(): - address = ray._private.services.get_node_ip_address() - # strip ipv6 address - address = address.strip("[]") - return address - -def get_free_port(start_port=10000, consecutive=1): - # find the port where port, port + 1, port + 2, ... 
port + consecutive - 1 are all available - port = start_port - while not all(is_port_available(port + i) for i in range(consecutive)): - port += 1 - return port - -def _wrap_ipv6(host): - """Wrap IPv6 address in [] if needed.""" - try: - ipaddress.IPv6Address(host.strip("[]")) - return f"[{host.strip('[]')}]" - except ipaddress.AddressValueError: - return host - - -def run_router(args): - try: - from sglang_router.launch_router import launch_router - - router = launch_router(args) - if router is None: - return 1 - return 0 - except Exception as e: - logger.info(e) - return 1 - - -def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None: - """Terminate a process gracefully, with forced kill as fallback. - - Args: - process: The process to terminate - timeout: Seconds to wait for graceful termination before forcing kill - """ - if not process.is_alive(): - return - - process.terminate() - process.join(timeout=timeout) - if process.is_alive(): - process.kill() - process.join() - - -_http_client: httpx.AsyncClient | None = None -_client_concurrency: int = 0 - -# Optional Ray-based distributed POST dispatch -_distributed_post_enabled: bool = False -_post_actors: list[object] = [] -_post_actor_idx: int = 0 - - -def _next_actor(): - global _post_actor_idx - if not _post_actors: - return None - actor = _post_actors[_post_actor_idx % len(_post_actors)] - _post_actor_idx = (_post_actor_idx + 1) % len(_post_actors) - return actor - - -async def _post(client, url, payload, max_retries=60, action="post"): - retry_count = 0 - while retry_count < max_retries: - try: - if action in ("delete", "get"): - assert not payload - response = await getattr(client, action)(url) - else: - response = await getattr(client, action)(url, json=payload or {}) - response.raise_for_status() - try: - output = response.json() - except json.JSONDecodeError: - output = response.text - except Exception as e: - retry_count += 1 - - if isinstance(e, httpx.HTTPStatusError): - 
response_text = e.response.text - else: - response_text = None - - logger.info( - f"Error: {e}, retrying... (attempt {retry_count}/{max_retries}, url={url}, response={response_text})" - ) - if retry_count >= max_retries: - logger.info(f"Max retries ({max_retries}) reached, failing... (url={url})") - raise e - await asyncio.sleep(1) - continue - break - - return output - - -def init_http_client(args: SGLangConfig): - """Initialize HTTP client and optionally enable distributed POST via Ray.""" - global _http_client, _client_concurrency, _distributed_post_enabled - if not args.get("sglang_server").get("num_gpus"): - return - - _client_concurrency = ( - args["sglang_server"]["sglang_server_concurrency"] - * args["sglang_server"]["num_gpus"] - // args["sglang_server"]["num_gpus_per_engine"] - ) - if _http_client is None: - _http_client = httpx.AsyncClient( - limits=httpx.Limits(max_connections=_client_concurrency), - timeout=httpx.Timeout(None), - ) - - # Optionally initialize distributed POST via Ray without changing interfaces - if args.get("sglang_router").get("use_distributed_post"): - _init_ray_distributed_post(args) - _distributed_post_enabled = True - - -def _init_ray_distributed_post(args: SGLangConfig): - """Initialize one or more Ray async actors per node for HTTP POST. - - Uses NodeAffinitySchedulingStrategy to place actors on distinct nodes. - Controlled by MILES_HTTP_POST_ACTORS_PER_NODE. 
- """ - global _post_actors - if _post_actors: - return # Already initialized - - import ray - from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy - - # Discover alive nodes - nodes = [n for n in ray.nodes() if n.get("Alive")] - if not nodes: - raise RuntimeError("No alive Ray nodes to place HTTP POST actors.") - - # Define the async actor - @ray.remote - class _HttpPosterActor: - def __init__(self, concurrency: int): - # Lazy creation to this actor's event loop - self._client = httpx.AsyncClient( - limits=httpx.Limits(max_connections=max(1, concurrency)), - timeout=httpx.Timeout(None), - ) - - async def do_post(self, url, payload, max_retries=60, action="post"): - return await _post(self._client, url, payload, max_retries, action=action) - - # Create actors per node - created = [] - # Distribute client concurrency across actors (at least 1 per actor) - per_actor_conc = (_client_concurrency + len(nodes)) // len(nodes) - - for node in nodes: - node_id = node["NodeID"] - scheduling = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) - for _ in range(args["sglang_server"]["num_gpus_per_engine"]): - actor = _HttpPosterActor.options( - name=None, - lifetime="detached", - scheduling_strategy=scheduling, - max_concurrency=per_actor_conc, - # Use tiny CPU to schedule - num_cpus=0.001, - ).remote(per_actor_conc) - created.append(actor) - - _post_actors = created - - -# TODO may generalize the name since it now contains http DELETE/GET etc (with retries and remote-execution) -async def post(url, payload, max_retries=60, action="post"): - # If distributed mode is enabled and actors exist, dispatch via Ray. 
- if _distributed_post_enabled and _post_actors: - try: - import ray - - actor = _next_actor() - if actor is not None: - # Use a thread to avoid blocking the event loop on ray.get - obj_ref = actor.do_post.remote(url, payload, max_retries, action=action) - return await asyncio.to_thread(ray.get, obj_ref) - except Exception as e: - logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") - # fall through to local - - return await _post(_http_client, url, payload, max_retries, action=action) - - -# TODO unify w/ `post` to add retries and remote-execution -async def get(url): - response = await _http_client.get(url) - response.raise_for_status() - output = response.json() - return output diff --git a/nemo_rl/models/generation/redesign/ray_utils.py b/nemo_rl/models/generation/redesign/ray_utils.py new file mode 100644 index 0000000000..f7d7d3a1ea --- /dev/null +++ b/nemo_rl/models/generation/redesign/ray_utils.py @@ -0,0 +1,147 @@ +import ipaddress +import os +import random +import socket + +import ray + + +class RayActor: + """Base class for Ray actors providing node IP / free port helpers.""" + + @staticmethod + def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): + return get_current_node_ip(), get_free_port( + start_port=start_port, consecutive=consecutive + ) + + def get_master_addr_and_port(self): + return self.master_addr, self.master_port + + +@ray.remote +class Lock(RayActor): + def __init__(self): + self._locked = False # False: unlocked, True: locked + + def acquire(self): + """ + Try to acquire the lock. Returns True if acquired, False otherwise. + Caller should retry until it returns True. + """ + if not self._locked: + self._locked = True + return True + return False + + def release(self): + """Release the lock, allowing others to acquire.""" + assert self._locked, "Lock is not acquired, cannot release." 
+ self._locked = False + + +def find_available_port(base_port: int): + port = base_port + random.randint(100, 1000) + while True: + if is_port_available(port): + return port + if port < 60000: + port += 42 + else: + port -= 43 + +def is_port_available(port): + """Return whether a port is available.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return True + except OSError: + return False + except OverflowError: + return False + +def get_host_info(): + hostname = socket.gethostname() + + def _is_loopback(ip): + return ip.startswith("127.") or ip == "::1" + + def _resolve_ip(family, test_target_ip): + """ + Attempt to get the local LAN IP for the specific family (IPv4/IPv6). + Strategy: UDP Probe (Preferred) -> Hostname Resolution (Fallback) -> None + """ + + # Strategy 1: UDP Connect Probe (Most accurate, relies on routing table) + # Useful when the machine has a default gateway or internet access. + try: + with socket.socket(family, socket.SOCK_DGRAM) as s: + # The IP doesn't need to be reachable, but the routing table must exist. + s.connect((test_target_ip, 80)) + ip = s.getsockname()[0] + if not _is_loopback(ip): + return ip + except Exception: + pass # Route unreachable or network error, move to next strategy. + + # Strategy 2: Hostname Resolution (Fallback for offline clusters) + # Useful for offline environments where UDP connect fails but /etc/hosts is configured. + try: + # getaddrinfo allows specifying the family (AF_INET or AF_INET6) + # Result format: [(family, type, proto, canonname, sockaddr), ...] 
+ infos = socket.getaddrinfo(hostname, None, family=family, type=socket.SOCK_STREAM) + + for info in infos: + ip = info[4][0] # The first element of sockaddr is the IP + # Must filter out loopback addresses to avoid "127.0.0.1" issues + if not _is_loopback(ip): + return ip + except Exception: + pass + + return None + + prefer_ipv6 = os.getenv("PREFER_IPV6", "0").lower() in ("1", "true", "yes", "on") + local_ip = None + final_fallback = "127.0.0.1" + + if prefer_ipv6: + # [Strict Mode] IPv6 Only + # 1. Try UDP V6 Probe + # 2. Try Hostname Resolution (V6) + # If failed, fallback to V6 loopback. Never mix with V4. + local_ip = _resolve_ip(socket.AF_INET6, "2001:4860:4860::8888") + final_fallback = "::1" + else: + # [Strict Mode] IPv4 Only (Default) + # 1. Try UDP V4 Probe + # 2. Try Hostname Resolution (V4) + # If failed, fallback to V4 loopback. Never mix with V6. + local_ip = _resolve_ip(socket.AF_INET, "8.8.8.8") + final_fallback = "127.0.0.1" + + return hostname, local_ip or final_fallback + +def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + address = address.strip("[]") + return address + +def get_free_port(start_port=10000, consecutive=1): + # find the port where port, port + 1, port + 2, ... 
port + consecutive - 1 are all available + port = start_port + while not all(is_port_available(port + i) for i in range(consecutive)): + port += 1 + return port + +def _wrap_ipv6(host): + """Wrap IPv6 address in [] if needed.""" + try: + ipaddress.IPv6Address(host.strip("[]")) + return f"[{host.strip('[]')}]" + except ipaddress.AddressValueError: + return host diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 3db55a2a17..51b94e2238 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -34,15 +34,19 @@ ) from nemo_rl.models.generation.redesign.async_utils import run from nemo_rl.models.generation.redesign.config import SGLangConfig -from nemo_rl.models.generation.redesign.ray_http_utils import ( +from nemo_rl.models.generation.redesign.http_utils import ( + init_http_client, + post, +) +from nemo_rl.models.generation.redesign.misc import ( NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, + run_router, +) +from nemo_rl.models.generation.redesign.ray_utils import ( Lock, _wrap_ipv6, find_available_port, get_host_info, - init_http_client, - post, - run_router, ) from nemo_rl.models.generation.redesign.sglang_worker import SGLangEngine @@ -850,7 +854,7 @@ def addr(): # --------------------------------------------------------------------------- def _start_router(args: SGLangConfig) -> tuple[str, int]: - """Start sgl router or miles router and return (router_ip, router_port). + """Start sgl router return (router_ip, router_port). If ``args.sglang_router_ip`` is already set and ``force_new`` is False, skip launching and return the existing values. 
diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 76506ce875..873c6915cc 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -14,7 +14,7 @@ from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError -from nemo_rl.models.generation.redesign.ray_http_utils import ( +from nemo_rl.models.generation.redesign.ray_utils import ( get_current_node_ip, get_free_port, get_host_info, From 2568512027246eaedd4cd834cc3e8ffe209861e0 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 14:24:14 -0700 Subject: [PATCH 13/36] update --- nemo_rl/models/generation/redesign/sglang_generation.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 51b94e2238..848d7f6806 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -143,6 +143,13 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis }.items() } + # Explicitly pass CUDA_VISIBLE_DEVICES through to the engine actor so + # all engines see the same global value (Ray would otherwise remap it + # because we set the NOSET_* flags above). 
+ global_cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if global_cvd: + env_vars["CUDA_VISIBLE_DEVICES"] = global_cvd + rollout_engine = RolloutRayActor.options( num_cpus=num_cpus, num_gpus=num_gpus, From 4f94d6e23afc42ca7efbb2b60a9b4e72c0e4930f Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 14:40:14 -0700 Subject: [PATCH 14/36] update --- .../models/generation/redesign/sglang_generation.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 848d7f6806..9195b972cb 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -240,15 +240,7 @@ def onload_weights(self): def onload_kv(self): handles = self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH]) return ray.get(handles) if handles else [] - - def onload_weights_from_disk(self): - """Reload weights from ``model_path`` for non-updatable groups.""" - if not self.needs_offload or not self.model_path: - return [] - return [ - engine.update_weights_from_disk.remote(self.model_path) for engine in self.engines if engine is not None - ] - + class SGLangGeneration(GenerationInterface): """The class to run rollout and convert rollout data to training data.""" From a478c692400b34559c23691569e61bce1a8e0ec2 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Fri, 10 Apr 2026 15:49:37 -0700 Subject: [PATCH 15/36] update --- nemo_rl/models/generation/redesign/config.py | 7 +- .../generation/redesign/fault_tolerance.py | 177 ++++++++++++++++++ .../models/generation/redesign/http_utils.py | 12 +- .../generation/redesign/sglang_generation.py | 103 +++++++--- .../generation/redesign/sglang_worker.py | 7 +- 5 files changed, 275 insertions(+), 31 deletions(-) create mode 100644 nemo_rl/models/generation/redesign/fault_tolerance.py diff --git 
a/nemo_rl/models/generation/redesign/config.py b/nemo_rl/models/generation/redesign/config.py index 7f74f47529..e74bce82b0 100644 --- a/nemo_rl/models/generation/redesign/config.py +++ b/nemo_rl/models/generation/redesign/config.py @@ -93,12 +93,17 @@ class SglangSpecificArgs(TypedDict): enable_fast_load: NotRequired[bool] # Server warmup skip_server_warmup: NotRequired[bool] + # Fault tolerance + use_fault_tolerance: NotRequired[bool] + rollout_health_check_interval: NotRequired[int] + rollout_health_check_timeout: NotRequired[int] + rollout_health_check_first_wait: NotRequired[int] class SGLangServer(TypedDict): needs_offload: bool sglang_server_concurrency: int num_gpus: NotRequired[int] - num_gpus_per_engine: NotRequired[int] + num_gpus_per_engine: NotRequired[int] class SGLangRouter(TypedDict): sglang_router_ip: NotRequired[str] diff --git a/nemo_rl/models/generation/redesign/fault_tolerance.py b/nemo_rl/models/generation/redesign/fault_tolerance.py new file mode 100644 index 0000000000..0a3e89538e --- /dev/null +++ b/nemo_rl/models/generation/redesign/fault_tolerance.py @@ -0,0 +1,177 @@ +import logging +import threading + +import ray + +logger = logging.getLogger(__name__) + +from nemo_rl.models.generation.redesign.config import SGLangConfig + +class RolloutHealthMonitor: + """Health monitor for rollout engines. + + The monitor runs continuously once started, but can be paused/resumed + based on whether the engines are offloaded (cannot health check when offloaded). 
+ + Lifecycle: + - start(): Start the monitor thread (called once during initialization) + - pause(): Pause health checking (called when offloading engines) + - resume(): Resume health checking (called when onloading engines) + - stop(): Stop the monitor thread completely (called during dispose) + """ + + def __init__(self, server_group, sglang_cfg: SGLangConfig): + self._server_group = server_group + + self._thread = None + self._stop_event = None + self._pause_event = None # When set, health checking is paused + self._check_interval = sglang_cfg["sglang_cfg"]["rollout_health_check_interval"] + self._check_timeout = sglang_cfg["sglang_cfg"]["rollout_health_check_timeout"] + self._check_first_wait = sglang_cfg["sglang_cfg"]["rollout_health_check_first_wait"] + self._need_first_wait = True # Need to wait after each resume + self._is_checking_enabled = False # Track if health checking should be active + + def start(self) -> bool: + """Start the health monitor thread. Called once during initialization. + + Returns: + True if the monitor was started, False if there are no engines to monitor. + """ + if not self._server_group.all_engines: + return False + + if self._thread is not None: + logger.warning("Health monitor thread is already running.") + return True + + logger.info("Starting RolloutHealthMonitor...") + self._stop_event = threading.Event() + self._pause_event = threading.Event() + self._pause_event.set() # Start in paused state until resume() is called + self._thread = threading.Thread( + target=self._health_monitor_loop, + name="RolloutHealthMonitor", + daemon=True, + ) + self._thread.start() + logger.info("RolloutHealthMonitor started (in paused state).") + return True + + def stop(self) -> None: + """Stop the health monitor thread completely. 
Called during dispose.""" + if not self._thread: + return + + logger.info("Stopping RolloutHealthMonitor...") + assert self._stop_event is not None + self._stop_event.set() + # Also clear pause to let the thread exit + if self._pause_event: + self._pause_event.clear() + timeout = self._check_timeout + self._check_interval + 5 + self._thread.join(timeout=timeout) + if self._thread.is_alive(): + logging.warning("Rollout health monitor thread did not terminate within %.1fs", timeout) + else: + logger.info("RolloutHealthMonitor stopped.") + + self._thread = None + self._stop_event = None + self._pause_event = None + self._is_checking_enabled = False + + def pause(self) -> None: + """Pause health checking. Called when engines are offloaded.""" + if self._pause_event is None: + return + logger.info("Pausing health monitor...") + self._pause_event.set() + self._is_checking_enabled = False + + def resume(self) -> None: + """Resume health checking. Called when engines are onloaded.""" + if self._pause_event is None: + return + logger.info("Resuming health monitor...") + self._need_first_wait = True # Need to wait after each resume + self._pause_event.clear() + self._is_checking_enabled = True + + def is_checking_enabled(self) -> bool: + """Return whether health checking is currently enabled (not paused).""" + return self._is_checking_enabled + + def _health_monitor_loop(self) -> None: + assert self._stop_event is not None + assert self._pause_event is not None + + while not self._stop_event.is_set(): + # Wait while paused + while self._pause_event.is_set() and not self._stop_event.is_set(): + self._stop_event.wait(timeout=0.5) + + if self._stop_event.is_set(): + break + + # Do first wait after each resume (for large MoE models to be ready) + if self._need_first_wait: + logger.info(f"Health monitor doing first wait after resume: {self._check_first_wait}s") + if self._stop_event.wait(self._check_first_wait): + logger.info("Health monitor stopped during first wait.") + break + 
if self._pause_event.is_set(): + # Got paused during first wait, skip this round and wait again next resume + logger.info("Health monitor paused during first wait, will wait again next resume.") + continue + self._need_first_wait = False + + # Run health checks + if not self._pause_event.is_set() and not self._stop_event.is_set(): + self._run_health_checks() + + # Wait for next check interval + if self._stop_event.wait(self._check_interval): + break + + def _run_health_checks(self) -> None: + for rollout_engine_id, engine in enumerate(self._server_group.engines): + if self._stop_event is not None and self._stop_event.is_set(): + break + if self._pause_event is not None and self._pause_event.is_set(): + break + self._check_engine_health(rollout_engine_id, engine) + + def _check_engine_health(self, rollout_engine_id, engine) -> None: + if engine is None: + logger.info(f"Skipping health check for engine {rollout_engine_id} (None)") + return + + try: + ray.get(engine.health_generate.remote(timeout=self._check_timeout)) + except Exception as e: + logger.error( + f"Health check failed for rollout engine {rollout_engine_id} (ray timeout or error). Killing actor. 
Exception: {e}" + ) + self._kill_engine(rollout_engine_id=rollout_engine_id) + else: + logger.debug(f"Health check passed for rollout engine {rollout_engine_id}") + + def _kill_engine(self, rollout_engine_id: int): + logger.info(f"Killing server group {rollout_engine_id}...") + for i in range( + rollout_engine_id * self._server_group.nodes_per_engine, + (rollout_engine_id + 1) * self._server_group.nodes_per_engine, + ): + engine = self._server_group.all_engines[i] + if engine: + logger.info(f"Shutting down and killing engine at index {i}") + try: + ray.get(engine.shutdown.remote()) + ray.kill(engine) + logger.info(f"Successfully killed engine at index {i}") + except Exception as e: + logger.warning(f"Fail to kill engine at index {i} (e: {e})") + else: + logger.info(f"Engine at index {i} is already None") + self._server_group.all_engines[i] = None diff --git a/nemo_rl/models/generation/redesign/http_utils.py b/nemo_rl/models/generation/redesign/http_utils.py index ccc83a4dfe..cae5e3dbf6 100644 --- a/nemo_rl/models/generation/redesign/http_utils.py +++ b/nemo_rl/models/generation/redesign/http_utils.py @@ -65,13 +65,14 @@ async def _post(client, url, payload, max_retries=60, action="post"): def init_http_client(args: SGLangConfig): """Initialize HTTP client and optionally enable distributed POST via Ray.""" global _http_client, _client_concurrency, _distributed_post_enabled - if not args.get("sglang_server").get("num_gpus"): + server_cfg = args.get("sglang_server") or {} + if not server_cfg.get("num_gpus"): return _client_concurrency = ( - args["sglang_server"]["sglang_server_concurrency"] - * args["sglang_server"]["num_gpus"] - // args["sglang_server"]["num_gpus_per_engine"] + server_cfg["sglang_server_concurrency"] + * server_cfg["num_gpus"] + // server_cfg["num_gpus_per_engine"] ) if _http_client is None: _http_client = httpx.AsyncClient( @@ -80,7 +81,8 @@ def init_http_client(args: SGLangConfig): ) # Optionally initialize distributed POST via Ray without changing 
interfaces - if args.get("sglang_router").get("use_distributed_post"): + router_cfg = args.get("sglang_router") or {} + if router_cfg.get("use_distributed_post"): _init_ray_distributed_post(args) _distributed_post_enabled = True diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 9195b972cb..b10dcd46da 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -41,6 +41,7 @@ from nemo_rl.models.generation.redesign.misc import ( NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, run_router, + terminate_process, ) from nemo_rl.models.generation.redesign.ray_utils import ( Lock, @@ -50,6 +51,8 @@ ) from nemo_rl.models.generation.redesign.sglang_worker import SGLangEngine +from nemo_rl.models.generation.redesign.fault_tolerance import RolloutHealthMonitor + logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) @@ -77,6 +80,10 @@ class ServerGroup: router_port: int | None = None cluster_cfg: Any = None sglang_cfg: Any = None + # Router subprocess handle. Only set when _start_router actually spawned + # the router (i.e. sglang_router_ip was not already configured). Kept so + # SGLangGeneration.shutdown() can terminate it cleanly. 
+ router_process: multiprocessing.Process | None = None @property def nodes_per_engine(self): @@ -192,6 +199,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis ] return init_handles, port_cursors + def recover(self): """Recover dead engines across all active groups, overlapping init.""" dead_indices = [i for i, engine in enumerate(self.all_engines) if engine is None] @@ -264,6 +272,13 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() + monitor = None + if sglang_cfg["sglang_cfg"].get("use_fault_tolerance"): + monitor = RolloutHealthMonitor(self.server_group, sglang_cfg) + monitor.start() + self._health_monitor = monitor + + @property def rollout_engines(self): """All node-0 engines across all servers / models.""" @@ -309,10 +324,7 @@ def recover_updatable_engines(self): server_group = self.server_group if server_group is None: - engines = server_group.engines if server_group else [] - gpu_counts = server_group.engine_gpu_counts if server_group else [] - gpu_offsets = server_group.engine_gpu_offsets if server_group else [] - return engines, self.rollout_engine_lock, (server_group.num_new_engines if server_group else 0), gpu_counts, gpu_offsets + return [], self.rollout_engine_lock, 0, [], [] server_group.recover() return ( @@ -329,7 +341,48 @@ def clear_updatable_num_new_engines(self): self.server_group.num_new_engines = 0 def check_weights(self, action: str): - return ray.get([engine.check_weights.remote(action=action) for engine in self.rollout_engines]) + return ray.get( + [ + engine.check_weights.remote(action=action) + for engine in self.rollout_engines + if engine is not None + ] + ) + + def health_monitoring_pause(self) -> None: + if self._health_monitor: + self._health_monitor.pause() + + def health_monitoring_resume(self) -> None: + if self._health_monitor: + self._health_monitor.resume() + + def shutdown(self) -> 
bool: + if self._health_monitor: + self._health_monitor.stop() + + ok = True + engines = [e for e in self.server_group.all_engines if e is not None] + if engines: + try: + ray.get([e.shutdown.remote() for e in engines]) + except Exception as e: + logger.warning(f"Engine shutdown failed: {e}") + ok = False + + router_process = self.server_group.router_process + if router_process is not None: + try: + terminate_process(router_process, timeout=3.0) + except Exception as e: + logger.warning(f"Router terminate failed: {e}") + ok = False + self.server_group.router_process = None + + return ok + + def __del__(self) -> None: + self.shutdown() def _merge_stop_strings(self, batch_stop_strings) -> list[list[str]]: """Merge stop strings from config and batch. @@ -749,9 +802,10 @@ async def generate_one_sample(sglang_router_ip, sglang_router_port, sampling_par else: response_tokens, response_log_probs = [], [] - response_truncated = False - if output["meta_info"]["output_token_logprobs"] is not None and output["meta_info"]["output_token_logprobs"] == "length": - response_truncated = True + # SGLang reports the termination reason under meta_info.finish_reason.type; + # "length" means the decoder hit max_new_tokens before EOS. + finish_reason = output["meta_info"].get("finish_reason") or {} + response_truncated = finish_reason.get("type") == "length" return index, response_tokens, response_log_probs, response_truncated @@ -852,31 +906,37 @@ def addr(): # Router + server bootstrap # --------------------------------------------------------------------------- -def _start_router(args: SGLangConfig) -> tuple[str, int]: - """Start sgl router return (router_ip, router_port). +def _start_router( + sglang_cfg: SGLangConfig, +) -> tuple[str, int, multiprocessing.Process | None]: + """Start sgl router, returning ``(router_ip, router_port, process)``. - If ``args.sglang_router_ip`` is already set and ``force_new`` is False, - skip launching and return the existing values. 
+ If ``sglang_router.sglang_router_ip`` is already set, reuse it and return + ``process=None`` (we do not own that router and must not terminate it). + Otherwise spawn a new router process and return its handle so the caller + can shut it down explicitly. """ - if args.sglang_router_ip is not None: - return args.sglang_router_ip, args.sglang_router_port + router_cfg = sglang_cfg["sglang_router"] + if router_cfg["sglang_router_ip"] is not None: + return router_cfg["sglang_router_ip"], router_cfg["sglang_router_port"], None router_ip = _wrap_ipv6(get_host_info()[1]) - router_port = args.sglang_router_port + router_port = router_cfg["sglang_router_port"] if router_port is None: router_port = find_available_port(random.randint(3000, 4000)) from sglang_router.launch_router import RouterArgs - # pass from router_args = RouterArgs() router_args.host = router_ip router_args.port = router_port - if args["sglang_router"]["router_policy"] is not None: - router_args.router_policy = args["sglang_router"]["router_policy"] + if router_cfg.get("router_policy") is not None: + router_args.router_policy = router_cfg["router_policy"] router_args.prometheus_port = find_available_port(random.randint(4000, 5000)) router_args.log_level = "warn" - router_args.request_timeout_secs = args.sglang_router_request_timeout_secs + request_timeout_secs = router_cfg.get("sglang_router_request_timeout_secs") + if request_timeout_secs is not None: + router_args.request_timeout_secs = request_timeout_secs logger.info(f"Launch router with args: {router_args}") @@ -889,7 +949,7 @@ def _start_router(args: SGLangConfig) -> tuple[str, int]: time.sleep(3) assert process.is_alive() logger.info(f"Router launched at {router_ip}:{router_port}") - return router_ip, router_port + return router_ip, router_port, process def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup: """Start rollout servers: one per model, each with its own router. 
@@ -900,7 +960,7 @@ def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup: engine_offset = 0 gpu_offset = 0 - router_ip, router_port = _start_router(sglang_cfg) + router_ip, router_port, router_process = _start_router(sglang_cfg) sglang_cfg["sglang_router"]["sglang_router_ip"] = router_ip sglang_cfg["sglang_router"]["sglang_router_port"] = router_port @@ -929,6 +989,7 @@ def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup: router_port=router_port, cluster_cfg=cluster_cfg, sglang_cfg=sglang_cfg, + router_process=router_process, ) handles, port_cursors = server_group.start_engines(port_cursors) diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 873c6915cc..b33b5792a3 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -53,7 +53,7 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: p.start() if server_args.node_rank != 0: - return + return p _wait_server_healthy( base_url=server_args.url(), @@ -124,8 +124,8 @@ def init( router_port=None, ): - self.router_ip = router_ip if router_ip is not None else self.sglang_cfg["sglang_cfg"]["sglang_router_ip"] - self.router_port = router_port if router_port is not None else self.sglang_cfg["sglang_cfg"]["sglang_router_port"] + self.router_ip = router_ip if router_ip is not None else self.sglang_cfg["sglang_router"]["sglang_router_ip"] + self.router_port = router_port if router_port is not None else self.sglang_cfg["sglang_router"]["sglang_router_port"] host = host or get_host_info()[1] @@ -278,7 +278,6 @@ def flush_cache(self): raise TimeoutError("Timeout while flushing cache.") def shutdown(self): - logger.info(f"Shutdown engine {self.server_host}:{self.server_port}...") if self.node_rank == 0: worker_url = f"http://{self.server_host}:{self.server_port}" From a943224adf44867c091c77fa5310020a6bbaeef6 Mon Sep 17 00:00:00 
2001 From: zhihaow6 Date: Fri, 10 Apr 2026 20:12:15 -0700 Subject: [PATCH 16/36] update --- .../models/generation/redesign/sglang_generation.py | 10 +++++++--- nemo_rl/models/generation/redesign/sglang_worker.py | 5 +++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index b10dcd46da..2e26ab2db5 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -26,6 +26,7 @@ RayVirtualCluster, get_reordered_bundle_and_gpu_ids, ) +from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationInterface, @@ -49,10 +50,12 @@ find_available_port, get_host_info, ) -from nemo_rl.models.generation.redesign.sglang_worker import SGLangEngine +from nemo_rl.models.generation.redesign.sglang_worker import SGLangGenerationWorker from nemo_rl.models.generation.redesign.fault_tolerance import RolloutHealthMonitor +from nemo_rl.utils.nsys import wrap_with_nvtx_name + logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) @@ -117,7 +120,6 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis num_gpu_per_engine = min(self.num_gpus_per_engine, self.num_gpus_per_node) pg, reordered_bundle_indices, reordered_gpu_ids = self.pg - RolloutRayActor = ray.remote(SGLangEngine) rollout_engines = [] for i in range(len(self.all_engines)): @@ -157,12 +159,13 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis if global_cvd: env_vars["CUDA_VISIBLE_DEVICES"] = global_cvd - rollout_engine = RolloutRayActor.options( + rollout_engine = SGLangGenerationWorker.options( num_cpus=num_cpus, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, runtime_env={ "env_vars": env_vars, + 
**get_nsight_config_if_pattern_matches("sglang_generation_worker"), }, ).remote( self.cluster_cfg, @@ -458,6 +461,7 @@ def _build_sampling_params( return sampling_params + @wrap_with_nvtx_name("sglang_generation/generate") def generate( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False ) -> BatchedDataDict[GenerationOutputSpec]: diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index b33b5792a3..77786b7abc 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -99,7 +99,8 @@ def _wait_server_healthy(base_url, api_key, is_process_alive): time.sleep(2) -class SGLangEngine: +@ray.remote # pragma: no cover +class SGLangGenerationWorker: def __init__( self, cluster_cfg, @@ -107,7 +108,7 @@ def __init__( rank: int, base_gpu_id: int | None = None, num_gpus_per_engine: int | None = None, - ): + ): self.cluster_cfg = cluster_cfg self.sglang_cfg = sglang_cfg self.rank = rank From 6f7c371b1d936f1fb5bc67d04af9341b95ff912d Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sat, 11 Apr 2026 09:27:30 -0700 Subject: [PATCH 17/36] update --- .../generation/redesign/sglang_generation.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 2e26ab2db5..e450eeeead 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -121,7 +121,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis num_gpu_per_engine = min(self.num_gpus_per_engine, self.num_gpus_per_node) pg, reordered_bundle_indices, reordered_gpu_ids = self.pg - rollout_engines = [] + local_all_engines = [] for i in range(len(self.all_engines)): if self.all_engines[i] is not None: continue @@ -175,10 +175,10 @@ def
start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis num_gpus_per_engine=self.num_gpus_per_engine, ) - rollout_engines.append((global_rank, rollout_engine)) + local_all_engines.append((global_rank, rollout_engine)) self.all_engines[i] = rollout_engine - self.num_new_engines = len(rollout_engines) + self.num_new_engines = len(local_all_engines) if self.num_new_engines == 0: return [], port_cursors @@ -187,7 +187,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( cluster_cfg=self.cluster_cfg, sglang_cfg=self.sglang_cfg, - rollout_engines=rollout_engines, + local_all_engines=local_all_engines, rank_offset=self.rank_offset, base_port=base_port, ) @@ -198,7 +198,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis router_ip=self.router_ip, router_port=self.router_port, ) - for rank, engine in rollout_engines + for rank, engine in local_all_engines ] return init_handles, port_cursors @@ -281,7 +281,6 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan monitor.start() self._health_monitor = monitor - @property def rollout_engines(self): """All node-0 engines across all servers / models.""" @@ -297,10 +296,11 @@ def get_updatable_engines_and_lock(self): return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets def offload(self, tags: list[str] | None = None): + """Release GPU memory on node-0 engines, optionally limited to the given tags.""" if tags is not None: handles = [ engine.release_memory_occupation.remote(tags=tags) - for engine in self.rollout_engines + for engine in self.rollout_engines if engine is not None ] return ray.get(handles) if handles else [] @@ -344,6 +344,7 @@ def clear_updatable_num_new_engines(self): self.server_group.num_new_engines = 0 def check_weights(self, action: str): + """Check weights on all live node-0 engines.""" return ray.get( [
engine.check_weights.remote(action=action) @@ -751,15 +752,15 @@ async def generate_async( output_ids_single_item[:input_length] = original_input_ids_single_row[ :input_length ] - if new_tokens: - output_ids_single_item[input_length:unpadded_length] = torch.tensor( - new_tokens, dtype=dtype, device=device - ) - # Logprobs: zeros for input tokens, raw floats at generated positions. logprobs_single_item = torch.zeros( (1, unpadded_length), dtype=torch.float32, device=device ) + + if new_tokens: + output_ids_single_item[input_length:unpadded_length] = torch.tensor( + new_tokens, dtype=dtype, device=device + ) if new_logprobs: logprobs_single_item[ 0, input_length : input_length + len(new_logprobs) @@ -822,7 +823,7 @@ def _allocate_rollout_engine_addr_and_ports_normal( *, cluster_cfg, sglang_cfg, - rollout_engines, + local_all_engines, rank_offset=0, base_port=15000, ): @@ -846,7 +847,7 @@ def _allocate_rollout_engine_addr_and_ports_normal( node_port_cursor: dict[int, int] = {} visited_nodes = set() - for rank, engine in rollout_engines: + for rank, engine in local_all_engines: local_rank = rank - rank_offset node_index = local_rank // num_engines_per_node if node_index in visited_nodes: @@ -899,7 +900,7 @@ def addr(): for i in range(num_engines_on_this_node): addr_and_ports[rank + i]["dist_init_addr"] = f"{get_addr()}:{get_port(30 + sglang_dp_size)}" - for i, _ in rollout_engines: + for i, _ in local_all_engines: for key in ["port", "nccl_port", "dist_init_addr"]: assert key in addr_and_ports[i], f"Engine {i} {key} is not set." 
logger.info(f"Ports for engine {i}: {addr_and_ports[i]}") From 92ee51a48eaf37da4ebf7d6996981d50fe248d50 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sat, 11 Apr 2026 09:57:11 -0700 Subject: [PATCH 18/36] update --- .../generation/redesign/fault_tolerance.py | 16 +- .../generation/redesign/sglang_generation.py | 344 ++++++++---------- 2 files changed, 157 insertions(+), 203 deletions(-) diff --git a/nemo_rl/models/generation/redesign/fault_tolerance.py b/nemo_rl/models/generation/redesign/fault_tolerance.py index 0a3e89538e..dba10b75e0 100644 --- a/nemo_rl/models/generation/redesign/fault_tolerance.py +++ b/nemo_rl/models/generation/redesign/fault_tolerance.py @@ -20,8 +20,8 @@ class RolloutHealthMonitor: - stop(): Stop the monitor thread completely (called during dispose) """ - def __init__(self, server_group, sglang_cfg: SGLangConfig): - self._server_group = server_group + def __init__(self, sglang_generation, sglang_cfg: SGLangConfig): + self._sglang_generation = sglang_generation self._thread = None self._stop_event = None @@ -38,7 +38,7 @@ def start(self) -> bool: Returns: True if the monitor was started, False if there are no engines to monitor. 
""" - if not self._server_group.all_engines: + if not self._sglang_generation.all_engines: return False if self._thread is not None: @@ -135,7 +135,7 @@ def _health_monitor_loop(self) -> None: break def _run_health_checks(self) -> None: - for rollout_engine_id, engine in enumerate(self._server_group.engines): + for rollout_engine_id, engine in enumerate(self._sglang_generation.engines): if self._stop_event is not None and self._stop_event.is_set(): break if self._pause_event is not None and self._pause_event.is_set(): @@ -160,10 +160,10 @@ def _check_engine_health(self, rollout_engine_id, engine) -> None: def _kill_engine(self, rollout_engine_id: int): logger.info(f"Killing server group {rollout_engine_id}...") for i in range( - rollout_engine_id * self._server_group.nodes_per_engine, - (rollout_engine_id + 1) * self._server_group.nodes_per_engine, + rollout_engine_id * self._sglang_generation.nodes_per_engine, + (rollout_engine_id + 1) * self._sglang_generation.nodes_per_engine, ): - engine = self._server_group.all_engines[i] + engine = self._sglang_generation.all_engines[i] if engine: logger.info(f"Shutting down and killing engine at index {i}") try: @@ -174,4 +174,4 @@ def _kill_engine(self, rollout_engine_id: int): logger.warning(f"Fail to kill engine at index {i} (e: {e})") else: logger.info(f"Engine at index {i} is already None") - self._server_group.all_engines[i] = None + self._sglang_generation.all_engines[i] = None diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index e450eeeead..b66b9109af 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -1,5 +1,4 @@ import asyncio -import dataclasses import itertools import logging import multiprocessing @@ -61,55 +60,107 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class ServerGroup: - """A group of homogeneous SGLang engines with the 
same configuration. +class SGLangGeneration(GenerationInterface): + """The class to run rollout and convert rollout data to training data. - All engines in a group share the same tp_size / nodes_per_engine / pg. - A RolloutServer may contain multiple ServerGroups (e.g. prefill vs decode - in PD disaggregation). + This class owns the full rollout server topology: the placement group, + the router subprocess, and every ``SGLangGenerationWorker`` Ray actor. + The former ``ServerGroup`` dataclass has been folded in so there is a + single source of truth for engine state. """ - pg: Any # (placement_group, reordered_bundle_indices, reordered_gpu_ids) - all_engines: list - num_gpus_per_engine: int - num_gpus_per_node: int - num_new_engines: int - rank_offset: int = 0 - gpu_offset: int = 0 - needs_offload: bool = False - model_path: str | None = None - router_ip: str | None = None - router_port: int | None = None - cluster_cfg: Any = None - sglang_cfg: Any = None - # Router subprocess handle. Only set when _start_router actually spawned - # the router (i.e. sglang_router_ip was not already configured). Kept so - # SGLangGeneration.shutdown() can terminate it cleanly. 
- router_process: multiprocessing.Process | None = None + def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglang_cfg: SGLangConfig): + self.cluster = cluster + self.cluster_cfg = cluster_cfg + self.sglang_cfg = sglang_cfg + + self.pg = cluster._init_placement_groups( + strategy="PACK", + use_unified_pg=True, + ) + self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids = get_reordered_bundle_and_gpu_ids(self.pg) + + init_http_client(sglang_cfg) + + # --- Engine topology (formerly ``ServerGroup``) ------------------ + gpus_per_engine = sglang_cfg["sglang_server"]["num_gpus_per_engine"] + num_gpus_per_node = cluster_cfg["gpus_per_node"] + num_gpu_per_engine_local = min(gpus_per_engine, num_gpus_per_node) + num_engines = sglang_cfg["sglang_server"]["num_gpus"] // num_gpu_per_engine_local + + self.num_gpus_per_engine: int = gpus_per_engine + self.num_gpus_per_node: int = num_gpus_per_node + self.all_engines: list = [None] * num_engines + self.num_new_engines: int = 0 + self.rank_offset: int = 0 + self.gpu_offset: int = 0 + self.needs_offload: bool = sglang_cfg["sglang_server"]["needs_offload"] + self.model_path: str | None = sglang_cfg["sglang_cfg"]["model_path"] + + # --- Router bootstrap -------------------------------------------- + router_ip, router_port, router_process = _start_router(sglang_cfg) + sglang_cfg["sglang_router"]["sglang_router_ip"] = router_ip + sglang_cfg["sglang_router"]["sglang_router_port"] = router_port + self.router_ip: str = router_ip + self.router_port: int = router_port + # Only set when ``_start_router`` actually spawned the router (i.e. + # sglang_router_ip was not already configured). Kept so ``shutdown`` + # can terminate it cleanly. 
+ self.router_process: multiprocessing.Process | None = router_process + + # --- Start engines ----------------------------------------------- + init_handles, _ = self._start_engines({}) + if init_handles: + ray.get(init_handles) + self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() + + monitor = None + if sglang_cfg["sglang_cfg"].get("use_fault_tolerance"): + monitor = RolloutHealthMonitor(self, sglang_cfg) + monitor.start() + self._health_monitor = monitor + + # ------------------------------------------------------------------ + # Engine topology properties (formerly ``ServerGroup``) + # ------------------------------------------------------------------ @property - def nodes_per_engine(self): + def nodes_per_engine(self) -> int: return max(1, self.num_gpus_per_engine // self.num_gpus_per_node) @property - def engines(self): - """Node-0 engines only (for multi-node serving).""" + def engines(self) -> list: + """Node-0 engines only (one entry per logical engine). + + For multi-node TP, ``all_engines`` contains ``nodes_per_engine`` + consecutive actors per logical engine; this slice returns just the + node-0 representative for each. 
+ """ return self.all_engines[:: self.nodes_per_engine] + @property + def rollout_engines(self) -> list: + """Alias for ``engines`` — node-0 engines across all servers / models.""" + return self.engines + @property def engine_gpu_counts(self) -> list[int]: - """Per-engine GPU count for all node-0 engines, parallel to ``engines``.""" + """Per-engine GPU count, parallel to ``engines``.""" return [self.num_gpus_per_engine for _ in self.engines] @property def engine_gpu_offsets(self) -> list[int]: - offsets = [] - for j in range(len(self.engines)): - offsets.append(self.gpu_offset + j * self.num_gpus_per_engine) - return offsets + return [ + self.gpu_offset + j * self.num_gpus_per_engine + for j in range(len(self.engines)) + ] - def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[list, dict[int, int]]: + # ------------------------------------------------------------------ + # Engine lifecycle (formerly ``ServerGroup.start_engines`` / ``recover``) + # ------------------------------------------------------------------ + def _start_engines( + self, port_cursors: dict[int, int] | None = None + ) -> tuple[list, dict[int, int]]: """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting. 
Returns ``(init_handles, port_cursors)`` where *init_handles* is a list @@ -119,7 +170,9 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis port_cursors = {} num_gpu_per_engine = min(self.num_gpus_per_engine, self.num_gpus_per_node) - pg, reordered_bundle_indices, reordered_gpu_ids = self.pg + pg = self.pg + reordered_bundle_indices = self.pg_reordered_bundle_indices + reordered_gpu_ids = self.pg_reordered_gpu_ids local_all_engines = [] for i in range(len(self.all_engines)): @@ -201,154 +254,105 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis for rank, engine in local_all_engines ] return init_handles, port_cursors - - - def recover(self): - """Recover dead engines across all active groups, overlapping init.""" + + def _recover(self) -> None: + """Recover dead engines, overlapping init.""" dead_indices = [i for i, engine in enumerate(self.all_engines) if engine is None] port_cursors: dict[int, int] = {} - handles = self.start_engines(port_cursors) + handles, _ = self._start_engines(port_cursors) if handles: ray.get(handles) - release_handles = [] - updatable_new_engines = [] + assert self.num_new_engines == len(dead_indices), ( + "num_new_engines does not match dead_indices length" + ) - assert self.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" if self.needs_offload and dead_indices: new_engines = [self.all_engines[i] for i in dead_indices] - release_handles.extend(engine.release_memory_occupation.remote() for engine in new_engines) - updatable_new_engines.extend(new_engines) - - if release_handles: + release_handles = [ + engine.release_memory_occupation.remote() for engine in new_engines + ] ray.get(release_handles) - all_resume_engines = updatable_new_engines[:] - if all_resume_engines: - ray.get( - [ - engine.resume_memory_occupation.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]) - for engine in all_resume_engines - ] - ) - - def offload(self): - if not 
self.needs_offload: - return [] - return [engine.release_memory_occupation.remote() for engine in self.engines if engine is not None] - - def onload(self, tags: list[str] | None = None): - if not self.needs_offload: - return [] - return [engine.resume_memory_occupation.remote(tags=tags) for engine in self.engines if engine is not None] - - def onload_weights(self): - if not self.needs_offload: - return - handles = self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS]) - return ray.get(handles) if handles else [] - - def onload_kv(self): - handles = self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH]) - return ray.get(handles) if handles else [] - - -class SGLangGeneration(GenerationInterface): - """The class to run rollout and convert rollout data to training data.""" - - def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglang_cfg: SGLangConfig): - self.cluster = cluster - self.cluster_cfg = cluster_cfg - self.sglang_cfg = sglang_cfg - self.pg = cluster._init_placement_groups( - strategy="PACK", - use_unified_pg=True, - ) - self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids = get_reordered_bundle_and_gpu_ids(self.pg) - - init_http_client(sglang_cfg) - self.server_group = start_rollout_servers( - sglang_cfg, - cluster_cfg, - (self.pg, self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids), - ) - - self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() - - monitor = None - if sglang_cfg["sglang_cfg"].get("use_fault_tolerance"): - monitor = RolloutHealthMonitor(self.server_group, sglang_cfg) - monitor.start() - self._health_monitor = monitor - - @property - def rollout_engines(self): - """All node-0 engines across all servers / models.""" - return [e for e in self.server_group.engines] + ray.get( + [ + engine.resume_memory_occupation.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + for engine in new_engines + ] + ) + # ------------------------------------------------------------------ + # Public API + # 
------------------------------------------------------------------ def get_updatable_engines_and_lock(self): """Return engines eligible for weight updates.""" - server_group = self.server_group - engines = server_group.engines if server_group else [] - gpu_counts = server_group.engine_gpu_counts if server_group else [] - gpu_offsets = server_group.engine_gpu_offsets if server_group else [] - num_new = server_group.num_new_engines if server_group else 0 - return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets + return ( + self.engines, + self.rollout_engine_lock, + self.num_new_engines, + self.engine_gpu_counts, + self.engine_gpu_offsets, + ) def offload(self, tags: list[str] | None = None): - """All node-0 engines across all servers / models.""" - if tags is not None: + """Release memory on all node-0 engines across all servers / models.""" + if tags is None and not self.needs_offload: + return [] + if tags is None: handles = [ - engine.release_memory_occupation.remote(tags=tags) - for engine in self.rollout_engines + engine.release_memory_occupation.remote() + for engine in self.engines if engine is not None ] - return ray.get(handles) if handles else [] else: - handles = self.server_group.offload() - return ray.get(handles) if handles else [] + handles = [ + engine.release_memory_occupation.remote(tags=tags) + for engine in self.engines + if engine is not None + ] + return ray.get(handles) if handles else [] def onload(self, tags: list[str] | None = None): - handles = self.server_group.onload(tags) + if not self.needs_offload: + return [] + handles = [ + engine.resume_memory_occupation.remote(tags=tags) + for engine in self.engines + if engine is not None + ] return ray.get(handles) if handles else [] def onload_weights(self): - self.server_group.onload_weights() + if not self.needs_offload: + return + self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS]) def onload_kv(self): - self.server_group.onload_kv() + self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, 
GPU_MEMORY_TYPE_CUDA_GRAPH]) def recover_updatable_engines(self): - """Restart any dead rollout engines and update num_new_engines for update_weights detection. - - Recovers the updatable model (the one that receives weight - updates from training). + """Restart any dead rollout engines and update ``num_new_engines`` + for weight-update detection. """ - server_group = self.server_group - - if server_group is None: - return [], self.rollout_engine_lock, 0, [], [] - - server_group.recover() + self._recover() return ( - server_group.engines, + self.engines, self.rollout_engine_lock, - server_group.num_new_engines, - server_group.engine_gpu_counts, - server_group.engine_gpu_offsets, + self.num_new_engines, + self.engine_gpu_counts, + self.engine_gpu_offsets, ) def clear_updatable_num_new_engines(self): # when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights - if self.server_group: - self.server_group.num_new_engines = 0 + self.num_new_engines = 0 def check_weights(self, action: str): """All node-0 engines across all servers / models.""" return ray.get( [ engine.check_weights.remote(action=action) - for engine in self.rollout_engines + for engine in self.engines if engine is not None ] ) @@ -366,7 +370,7 @@ def shutdown(self) -> bool: self._health_monitor.stop() ok = True - engines = [e for e in self.server_group.all_engines if e is not None] + engines = [e for e in self.all_engines if e is not None] if engines: try: ray.get([e.shutdown.remote() for e in engines]) @@ -374,14 +378,13 @@ def shutdown(self) -> bool: logger.warning(f"Engine shutdown failed: {e}") ok = False - router_process = self.server_group.router_process - if router_process is not None: + if self.router_process is not None: try: - terminate_process(router_process, timeout=3.0) + terminate_process(self.router_process, timeout=3.0) except Exception as e: logger.warning(f"Router terminate failed: {e}") ok = False - self.server_group.router_process = None + 
self.router_process = None return ok @@ -955,52 +958,3 @@ def _start_router( assert process.is_alive() logger.info(f"Router launched at {router_ip}:{router_port}") return router_ip, router_port, process - -def start_rollout_servers(sglang_cfg, cluster_cfg, pg) -> ServerGroup: - """Start rollout servers: one per model, each with its own router. - - Returns a dict mapping model name -> ``RolloutServer``. - """ - - engine_offset = 0 - gpu_offset = 0 - - router_ip, router_port, router_process = _start_router(sglang_cfg) - - sglang_cfg["sglang_router"]["sglang_router_ip"] = router_ip - sglang_cfg["sglang_router"]["sglang_router_port"] = router_port - - all_init_handles: list = [] - port_cursors: dict[int, int] = {} - - gpus_per_engine = sglang_cfg["sglang_server"]["num_gpus_per_engine"] - num_gpu_per_engine_local = min(gpus_per_engine, cluster_cfg["gpus_per_node"]) - num_engines = sglang_cfg["sglang_server"]["num_gpus"] // num_gpu_per_engine_local - needs_offload = sglang_cfg["sglang_server"]["needs_offload"] - num_gpus_per_node = cluster_cfg["gpus_per_node"] - model_path= sglang_cfg["sglang_cfg"]["model_path"] - - server_group = ServerGroup( - pg=pg, - all_engines=[None] * num_engines, - num_gpus_per_engine=gpus_per_engine, - num_gpus_per_node=num_gpus_per_node, - num_new_engines=0, - rank_offset=engine_offset, - gpu_offset=gpu_offset, - needs_offload=needs_offload, - model_path=model_path, - router_ip=router_ip, - router_port=router_port, - cluster_cfg=cluster_cfg, - sglang_cfg=sglang_cfg, - router_process=router_process, - ) - - handles, port_cursors = server_group.start_engines(port_cursors) - all_init_handles.extend(handles) - - if all_init_handles: - ray.get(all_init_handles) - - return server_group From 64338059106f08f55084bcac269b08b80a02bb89 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sat, 11 Apr 2026 10:55:37 -0700 Subject: [PATCH 19/36] update --- .../generation/redesign/sglang_generation.py | 34 +++++++++++++++++++ .../generation/redesign/sglang_worker.py 
| 15 ++++++++ 2 files changed, 49 insertions(+) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index b66b9109af..62e39ee1f0 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -788,6 +788,40 @@ async def generate_async( yield (sample_idx, result_batch) +# --------------------------------------------------------------------------- +# Compatible with parent class or old interfaces +# --------------------------------------------------------------------------- + def init_collective( + self, ip: str, port: int, world_size: int, *, train_world_size: int + ) -> list[ray.ObjectRef]: + return [] + + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: + pass + + def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: + return [] + + def update_weights_from_collective(self) -> list[ray.ObjectRef]: + return [] + + def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: + """Wake workers up for colocated inference.""" + pass + + def finish_generation(self, *args: Any, **kwargs: Any) -> bool: + """Sleep workers and reset prefix cache.""" + pass + + def invalidate_kv_cache(self) -> bool: + pass + + def get_sglang_server_urls(self) -> list[str]: + pass + + def get_sglang_url_to_gpu_uuids(self) -> dict[str, list[str]]: + pass + # --------------------------------------------------------------------------- # Generate one sample helper # --------------------------------------------------------------------------- diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 77786b7abc..1efc384dca 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -460,6 +460,21 @@ def simulate_crash(self): logger.info(f"Simulating crash on engine 
{self.server_host}:{self.server_port}...") self.shutdown() +# --------------------------------------------------------------------------- +# Compatible with parent class or old interfaces +# --------------------------------------------------------------------------- + def get_base_url(self) -> str | None: + pass + + def get_gpu_uuids(self) -> list[str]: + pass + + def invalidate_kv_cache(self) -> bool: + pass + +# ---------------------------------------------------------------------------- +# Compute Server args +# ---------------------------------------------------------------------------- def _compute_server_args( cluster_cfg, sglang_cfg, From 646cca7b593b244bb90c0847287d112490bef165 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 12 Apr 2026 00:22:35 -0700 Subject: [PATCH 20/36] update --- .../ray_actor_environment_registry.py | 1 + .../generation/redesign/sglang_actors.py | 87 ++++++ .../generation/redesign/sglang_generation.py | 294 +++++++++++------- .../generation/redesign/sglang_worker.py | 89 +++++- 4 files changed, 336 insertions(+), 135 deletions(-) create mode 100644 nemo_rl/models/generation/redesign/sglang_actors.py diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 95677873a4..9a5b79561e 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -31,6 +31,7 @@ "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE, "nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE, "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE, + "nemo_rl.models.generation.redesign.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE, "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.FSDP, "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": 
PY_EXECUTABLES.AUTOMODEL, "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE, diff --git a/nemo_rl/models/generation/redesign/sglang_actors.py b/nemo_rl/models/generation/redesign/sglang_actors.py new file mode 100644 index 0000000000..0bd841c047 --- /dev/null +++ b/nemo_rl/models/generation/redesign/sglang_actors.py @@ -0,0 +1,87 @@ +import importlib +import logging + +import ray + +logger = logging.getLogger(__name__) + +SGLANG_WORKER_FQN = "nemo_rl.models.generation.redesign.sglang_worker.SGLangGenerationWorker" + + +@ray.remote +class SGLangWorkerInitializer: + """Loads and constructs SGLangGenerationWorker inside the sglang env. + + Mirrors the role of RayWorkerBuilder.IsolatedWorkerInitializer. + We spawn this actor with runtime_env={"py_executable": SGLANG_EXECUTABLE}, + then call importlib.import_module on the worker module from *inside* this + actor — which means the worker module's top-level sglang imports execute in + a process that actually has sglang installed. + """ + + def __init__(self, fqn: str): + self._fqn = fqn + + def create(self, actor_options: dict, init_args: tuple, init_kwargs: dict): + module_name, class_name = self._fqn.rsplit(".", 1) + module = importlib.import_module(module_name) + worker_class = getattr(module, class_name) + return worker_class.options(**actor_options).remote(*init_args, **init_kwargs) + + +@ray.remote(num_cpus=1, num_gpus=0) +class RouterActor: + """Starts and owns the sglang router subprocess. + + Runs under SGLANG_EXECUTABLE so it can import sglang_router. + The driver (SYSTEM env) holds a handle to this actor and retrieves + (router_ip, router_port) without ever importing sglang_router itself. 
+ """ + + def start(self, router_cfg: dict) -> tuple[str, int]: + import multiprocessing + import random + + from sglang_router.launch_router import RouterArgs + + from nemo_rl.models.generation.redesign.misc import run_router + from nemo_rl.models.generation.redesign.ray_utils import ( + _wrap_ipv6, + find_available_port, + get_host_info, + ) + + router_ip = _wrap_ipv6(get_host_info()[1]) + router_port = router_cfg.get("sglang_router_port") + if router_port is None: + router_port = find_available_port(random.randint(3000, 4000)) + + router_args = RouterArgs() + router_args.host = router_ip + router_args.port = router_port + if router_cfg.get("router_policy") is not None: + router_args.router_policy = router_cfg["router_policy"] + router_args.prometheus_port = find_available_port( + random.randint(4000, 5000) + ) + router_args.log_level = "warn" + request_timeout_secs = router_cfg.get( + "sglang_router_request_timeout_secs" + ) + if request_timeout_secs is not None: + router_args.request_timeout_secs = request_timeout_secs + + self._process = multiprocessing.Process( + target=run_router, args=(router_args,) + ) + self._process.daemon = True + self._process.start() + import time + time.sleep(3) + assert self._process.is_alive(), "Router process died on startup" + return router_ip, router_port + + def stop(self): + from nemo_rl.models.generation.redesign.misc import terminate_process + if hasattr(self, "_process"): + terminate_process(self._process) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 62e39ee1f0..488073fd2f 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -1,11 +1,7 @@ import asyncio -import itertools import logging -import multiprocessing import os -import random import threading -import time from pathlib import Path from typing import Any, AsyncGenerator @@ -13,11 +9,6 @@ import ray 
import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from sglang.srt.constants import ( - GPU_MEMORY_TYPE_CUDA_GRAPH, - GPU_MEMORY_TYPE_KV_CACHE, - GPU_MEMORY_TYPE_WEIGHTS, -) from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import ( @@ -38,19 +29,14 @@ init_http_client, post, ) -from nemo_rl.models.generation.redesign.misc import ( - NOSET_VISIBLE_DEVICES_ENV_VARS_LIST, - run_router, - terminate_process, -) -from nemo_rl.models.generation.redesign.ray_utils import ( - Lock, - _wrap_ipv6, - find_available_port, - get_host_info, +from nemo_rl.distributed.ray_actor_environment_registry import SGLANG_EXECUTABLE +from nemo_rl.models.generation.redesign.misc import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST +from nemo_rl.models.generation.redesign.ray_utils import Lock +from nemo_rl.models.generation.redesign.sglang_actors import ( + SGLANG_WORKER_FQN, + RouterActor, + SGLangWorkerInitializer, ) -from nemo_rl.models.generation.redesign.sglang_worker import SGLangGenerationWorker - from nemo_rl.models.generation.redesign.fault_tolerance import RolloutHealthMonitor from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ -98,7 +84,7 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan self.model_path: str | None = sglang_cfg["sglang_cfg"]["model_path"] # --- Router bootstrap -------------------------------------------- - router_ip, router_port, router_process = _start_router(sglang_cfg) + router_ip, router_port, router_actor = _start_router(sglang_cfg) sglang_cfg["sglang_router"]["sglang_router_ip"] = router_ip sglang_cfg["sglang_router"]["sglang_router_port"] = router_port self.router_ip: str = router_ip @@ -106,7 +92,7 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan # Only set when ``_start_router`` actually spawned the router (i.e. # sglang_router_ip was not already configured). 
Kept so ``shutdown`` # can terminate it cleanly. - self.router_process: multiprocessing.Process | None = router_process + self._router_actor: ray.actor.ActorHandle | None = router_actor # --- Start engines ----------------------------------------------- init_handles, _ = self._start_engines({}) @@ -174,7 +160,12 @@ def _start_engines( reordered_bundle_indices = self.pg_reordered_bundle_indices reordered_gpu_ids = self.pg_reordered_gpu_ids - local_all_engines = [] + # One initializer per _start_engines() call (not per engine). + initializer = SGLangWorkerInitializer.options( + runtime_env={"py_executable": SGLANG_EXECUTABLE}, + ).remote(SGLANG_WORKER_FQN) + + engine_refs: list[tuple[int, int, ray.ObjectRef]] = [] # (index, rank, ref) for i in range(len(self.all_engines)): if self.all_engines[i] is not None: continue @@ -212,30 +203,40 @@ def _start_engines( if global_cvd: env_vars["CUDA_VISIBLE_DEVICES"] = global_cvd - rollout_engine = SGLangGenerationWorker.options( - num_cpus=num_cpus, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - runtime_env={ + actor_options = { + "num_cpus": num_cpus, + "num_gpus": num_gpus, + "scheduling_strategy": scheduling_strategy, + "runtime_env": { + "py_executable": SGLANG_EXECUTABLE, "env_vars": env_vars, **get_nsight_config_if_pattern_matches("sglang_generation_worker"), }, - ).remote( - self.cluster_cfg, - self.sglang_cfg, - rank=global_rank, - base_gpu_id=base_gpu_id, - num_gpus_per_engine=self.num_gpus_per_engine, - ) - - local_all_engines.append((global_rank, rollout_engine)) - self.all_engines[i] = rollout_engine + } + init_args = (self.cluster_cfg, self.sglang_cfg) + init_kwargs = { + "rank": global_rank, + "base_gpu_id": base_gpu_id, + "num_gpus_per_engine": self.num_gpus_per_engine, + } - self.num_new_engines = len(local_all_engines) + # Collect refs — do NOT ray.get inside the loop (preserve parallel creation). 
+ rollout_engine_ref = initializer.create.remote(actor_options, init_args, init_kwargs) + engine_refs.append((i, global_rank, rollout_engine_ref)) - if self.num_new_engines == 0: + # Resolve all engine actor handles in parallel. + if not engine_refs: + self.num_new_engines = 0 return [], port_cursors + resolved_engines = ray.get([ref for _, _, ref in engine_refs]) + local_all_engines = [] + for (i, global_rank, _), engine in zip(engine_refs, resolved_engines): + self.all_engines[i] = engine + local_all_engines.append((global_rank, engine)) + + self.num_new_engines = len(local_all_engines) + base_port = max(port_cursors.values()) if port_cursors else 15000 addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( cluster_cfg=self.cluster_cfg, @@ -270,16 +271,8 @@ def _recover(self) -> None: if self.needs_offload and dead_indices: new_engines = [self.all_engines[i] for i in dead_indices] - release_handles = [ - engine.release_memory_occupation.remote() for engine in new_engines - ] - ray.get(release_handles) - ray.get( - [ - engine.resume_memory_occupation.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]) - for engine in new_engines - ] - ) + ray.get([engine.release_memory_weights.remote() for engine in new_engines]) + ray.get([engine.resume_memory_weights.remote() for engine in new_engines]) # ------------------------------------------------------------------ # Public API @@ -294,41 +287,45 @@ def get_updatable_engines_and_lock(self): self.engine_gpu_offsets, ) - def offload(self, tags: list[str] | None = None): - """Release memory on all node-0 engines across all servers / models.""" - if tags is None and not self.needs_offload: - return [] - if tags is None: - handles = [ - engine.release_memory_occupation.remote() - for engine in self.engines - if engine is not None - ] - else: - handles = [ - engine.release_memory_occupation.remote(tags=tags) - for engine in self.engines - if engine is not None - ] - return ray.get(handles) if handles else [] - - def 
onload(self, tags: list[str] | None = None): + def offload_weights(self): if not self.needs_offload: - return [] + return + handles = [ + engine.release_memory_weights.remote() + for engine in self.engines + if engine is not None + ] + if handles: + ray.get(handles) + + def offload_kv(self): handles = [ - engine.resume_memory_occupation.remote(tags=tags) + engine.release_memory_kv_cache_and_cuda_graph.remote() for engine in self.engines if engine is not None ] - return ray.get(handles) if handles else [] + if handles: + ray.get(handles) def onload_weights(self): if not self.needs_offload: return - self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + handles = [ + engine.resume_memory_weights.remote() + for engine in self.engines + if engine is not None + ] + if handles: + ray.get(handles) def onload_kv(self): - self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH]) + handles = [ + engine.resume_memory_kv_cache_and_cuda_graph.remote() + for engine in self.engines + if engine is not None + ] + if handles: + ray.get(handles) def recover_updatable_engines(self): """Restart any dead rollout engines and update ``num_new_engines`` @@ -378,13 +375,14 @@ def shutdown(self) -> bool: logger.warning(f"Engine shutdown failed: {e}") ok = False - if self.router_process is not None: + if self._router_actor is not None: try: - terminate_process(self.router_process, timeout=3.0) + ray.get(self._router_actor.stop.remote()) + ray.kill(self._router_actor) except Exception as e: logger.warning(f"Router terminate failed: {e}") ok = False - self.router_process = None + self._router_actor = None return ok @@ -814,13 +812,100 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool: pass def invalidate_kv_cache(self) -> bool: - pass + """Invalidate KV cache before weight updates (Megatron-style). + + Flushes the cache on every node-0 engine so stale KV entries are + discarded before new weights land. Returns ``True`` iff every engine + reports success. 
+ """ + engines = [e for e in self.engines if e is not None] + if not engines: + return True + try: + results = ray.get([e.invalidate_kv_cache.remote() for e in engines]) + except Exception as e: + logger.error(f"[sglang refit] Error flushing SGLang caches: {e}") + return False + + success = all(results) + if success: + logger.info("[sglang refit] All SGLang server caches flushed successfully") + else: + logger.warning( + "[sglang refit] WARNING - Some SGLang server caches failed to flush" + ) + return success def get_sglang_server_urls(self) -> list[str]: - pass + """Return the base URLs of all SGLang servers (one per logical engine). + + Returns: + List of base URLs, e.g. ``["http://host-a:30000", "http://host-b:30001"]``. + """ + engines = [e for e in self.engines if e is not None] + if not engines: + raise RuntimeError("No rollout engines initialized") + + urls = ray.get([e.get_base_url.remote() for e in engines]) + return list({u for u in urls if u is not None}) def get_sglang_url_to_gpu_uuids(self) -> dict[str, list[str]]: - pass + """Return a mapping of SGLang server URL to the GPU UUIDs it owns. + + For multi-node TP, a single logical engine spans ``nodes_per_engine`` + consecutive actors in ``all_engines``; we key by the node-0 URL and + concatenate the local-node UUID slices from every peer in the group. + + Returns: + Dict mapping base URL to list of GPU UUIDs, + e.g. ``{"http://host-a:30000": ["GPU-aaa", "GPU-bbb"], ...}``. + """ + if not any(e is not None for e in self.all_engines): + raise RuntimeError("No rollout engines initialized") + + nodes_per_engine = self.nodes_per_engine + + # Fan out to every actor (not just node-0) so peer ranks can report + # their own GPU slice. ``None`` slots mean the engine was killed by + # the health monitor and not yet recovered; skip them. 
+ url_refs = [ + e.get_base_url.remote() if e is not None else None + for e in self.all_engines + ] + uuid_refs = [ + e.get_gpu_uuids.remote() if e is not None else None + for e in self.all_engines + ] + urls_live = ray.get([r for r in url_refs if r is not None]) + uuids_live = ray.get([r for r in uuid_refs if r is not None]) + + # Re-key results back to their original slot index so gaps (None + # engines) stay aligned with ``all_engines``. + urls_by_slot: list[str | None] = [] + uuids_by_slot: list[list[str] | None] = [] + u_iter = iter(urls_live) + g_iter = iter(uuids_live) + for engine in self.all_engines: + if engine is None: + urls_by_slot.append(None) + uuids_by_slot.append(None) + else: + urls_by_slot.append(next(u_iter)) + uuids_by_slot.append(next(g_iter)) + + url_to_uuids: dict[str, list[str]] = {} + for group_start in range(0, len(self.all_engines), nodes_per_engine): + node0_url = urls_by_slot[group_start] + if node0_url is None: + continue + aggregated: list[str] = [] + for i in range(group_start, group_start + nodes_per_engine): + slot_uuids = uuids_by_slot[i] + if slot_uuids: + aggregated.extend(slot_uuids) + if aggregated: + url_to_uuids[node0_url] = aggregated + return url_to_uuids # --------------------------------------------------------------------------- # Generate one sample helper @@ -944,51 +1029,22 @@ def addr(): return addr_and_ports, node_port_cursor -# --------------------------------------------------------------------------- -# Router + server bootstrap -# --------------------------------------------------------------------------- - def _start_router( sglang_cfg: SGLangConfig, -) -> tuple[str, int, multiprocessing.Process | None]: - """Start sgl router, returning ``(router_ip, router_port, process)``. +) -> tuple[str, int, ray.actor.ActorHandle | None]: + """Start sgl router, returning ``(router_ip, router_port, actor_handle)``. 
If ``sglang_router.sglang_router_ip`` is already set, reuse it and return - ``process=None`` (we do not own that router and must not terminate it). - Otherwise spawn a new router process and return its handle so the caller - can shut it down explicitly. + ``actor_handle=None`` (we do not own that router and must not terminate it). + Otherwise spawn a ``RouterActor`` in sglang env to own the router process. """ router_cfg = sglang_cfg["sglang_router"] if router_cfg["sglang_router_ip"] is not None: return router_cfg["sglang_router_ip"], router_cfg["sglang_router_port"], None - router_ip = _wrap_ipv6(get_host_info()[1]) - router_port = router_cfg["sglang_router_port"] - if router_port is None: - router_port = find_available_port(random.randint(3000, 4000)) - - from sglang_router.launch_router import RouterArgs - - router_args = RouterArgs() - router_args.host = router_ip - router_args.port = router_port - if router_cfg.get("router_policy") is not None: - router_args.router_policy = router_cfg["router_policy"] - router_args.prometheus_port = find_available_port(random.randint(4000, 5000)) - router_args.log_level = "warn" - request_timeout_secs = router_cfg.get("sglang_router_request_timeout_secs") - if request_timeout_secs is not None: - router_args.request_timeout_secs = request_timeout_secs - - logger.info(f"Launch router with args: {router_args}") - - process = multiprocessing.Process( - target=run_router, - args=(router_args,), - ) - process.daemon = True - process.start() - time.sleep(3) - assert process.is_alive() + router_actor = RouterActor.options( + runtime_env={"py_executable": SGLANG_EXECUTABLE}, + ).remote() + router_ip, router_port = ray.get(router_actor.start.remote(dict(router_cfg))) logger.info(f"Router launched at {router_ip}:{router_port}") - return router_ip, router_port, process + return router_ip, router_port, router_actor diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 
1efc384dca..50af9fdbc5 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -10,6 +10,11 @@ import requests import sglang_router from packaging.version import parse +from sglang.srt.constants import ( + GPU_MEMORY_TYPE_CUDA_GRAPH, + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, +) from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError @@ -209,9 +214,6 @@ def _make_request(self, endpoint: str, payload: dict | None = None): def _get_current_node_ip_and_free_port(start_port=10000, consecutive=1): return get_current_node_ip(), get_free_port(start_port=start_port, consecutive=consecutive) - def get_master_addr_and_port(self): - return self.master_addr, self.master_port - def health_generate(self, timeout: float = 5.0) -> bool: """Run /health_generate on the underlying SGLang HTTP server. @@ -337,19 +339,25 @@ def resume_memory_occupation(self, tags: list[str] = None): {"tags": tags}, ) + def release_memory_weights(self): + return self.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + + def release_memory_kv_cache_and_cuda_graph(self): + return self.release_memory_occupation( + tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] + ) + + def resume_memory_weights(self): + return self.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS]) + + def resume_memory_kv_cache_and_cuda_graph(self): + return self.resume_memory_occupation( + tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] + ) + def check_weights(self, action: str): return self._make_request("weights_checker", {"action": action}) - def update_weights_from_disk(self, model_path: str, load_format: str | None = None): - """Reload weights from *model_path* without restarting the engine. - - Used for non-updatable (frozen) models that overlap with megatron: - after offload, weights are restored from disk instead of CPU cache. 
- """ - payload = {"model_path": model_path} - if load_format is not None: - payload["load_format"] = load_format - return self._make_request("update_weights_from_disk", payload) def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): return self._make_request( @@ -464,13 +472,62 @@ def simulate_crash(self): # Compatible with parent class or old interfaces # --------------------------------------------------------------------------- def get_base_url(self) -> str | None: - pass + """Return the ``http://host:port`` base URL of this SGLang server. + + Only node-rank 0 owns the HTTP server; peer ranks return ``None`` so + callers can filter them out when collecting per-engine URLs. + """ + if self.node_rank != 0: + return None + return f"http://{self.server_host}:{self.server_port}" def get_gpu_uuids(self) -> list[str]: - pass + """Return the GPU UUIDs this actor owns on its local node. + + SGLang lays out GPUs contiguously starting at ``base_gpu_id``. Every + rank (including peer nodes in multi-node TP) reports its own + local-node slice of ``min(num_gpus_per_engine, gpus_per_node)`` GPUs; + the orchestrator concatenates the slices across peers to rebuild the + full UUID list for a logical engine. + """ + from nemo_rl.utils.nvml import get_device_uuid + + num_local_gpus = min( + self.num_gpus_per_engine, + self.cluster_cfg["gpus_per_node"], + ) + # ``self.base_gpu_id`` stores the *physical* GPU id handed down by + # the orchestrator, but ``get_device_uuid`` indexes into + # ``CUDA_VISIBLE_DEVICES`` and therefore expects a *local* id — so + # remap before calling it. + local_base = _to_local_gpu_id(self.base_gpu_id) + return [get_device_uuid(local_base + i) for i in range(num_local_gpus)] def invalidate_kv_cache(self) -> bool: - pass + """Flush the cache of the server. + + Returns: + True on a successful flush, False on timeout / error. Peer + (non-node-0) ranks return True since they do not own the HTTP + server. 
+ """ + if self.node_rank != 0: + return True + # flush cache will not return status_code 200 when there are pending requests + for _ in range(60): + try: + response = requests.get(f"http://{self.server_host}:{self.server_port}/flush_cache") + if response.status_code == 200: + return True + except NewConnectionError as e: + logger.error(f"Connection error flushing cache: {e}") + return False + except Exception as e: + logger.info(f"Error flushing cache: {e}") + time.sleep(1) + continue + logger.error("Timeout while flushing cache.") + return False # ---------------------------------------------------------------------------- # Compute Server args From 2e55a36741d708d420b1e63eb67e1d62f1d84fe9 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 12 Apr 2026 00:46:11 -0700 Subject: [PATCH 21/36] update --- nemo_rl/algorithms/grpo.py | 17 +++++++++++++++-- nemo_rl/models/policy/utils.py | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e550429ce2..447f4f4ebb 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -68,7 +68,8 @@ run_multi_turn_rollout, ) from nemo_rl.models.generation.interfaces import GenerationInterface -from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration +from nemo_rl.models.generation.redesign.config import SGLangConfig +from nemo_rl.models.generation.redesign.sglang_generation import SGLangGeneration from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface @@ -438,6 +439,10 @@ def init_train_dataloader(dataset, suffix: str = ""): ) train_cluster = cluster inference_cluster = cluster + inference_cluster_cfg: ClusterConfig = { + "gpus_per_node": policy_gpus_per_node, + "num_nodes": policy_nodes, + } print( f" ✓ Ray cluster for policy initialized with {policy_nodes} nodes", flush=True, @@ -526,6 
+531,10 @@ def init_train_dataloader(dataset, suffix: str = ""): num_gpus_per_node=inference_gpus_per_node, max_colocated_worker_groups=1, ) + inference_cluster_cfg: ClusterConfig = { + "gpus_per_node": inference_gpus_per_node, + "num_nodes": inference_nodes, + } print( f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node", flush=True, @@ -578,7 +587,11 @@ def init_vllm(): def init_sglang(): """Initialize SGLang generation workers.""" t0 = time.perf_counter() - pg = SGLangGeneration(cluster=inference_cluster, config=generation_config) + pg = SGLangGeneration( + cluster=inference_cluster, + cluster_cfg=inference_cluster_cfg, + sglang_cfg=generation_config, + ) pg.finish_generation() return pg, time.perf_counter() - t0 diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index bbd2e6d2f6..02eb9e6ff8 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -405,7 +405,7 @@ def stream_weights_via_http_impl( worker_name: Name of the worker for logging current_device_uuid: UUID of the current training worker's GPU """ - from nemo_rl.models.generation.sglang.sglang_copied_utils import ( + from nemo_rl.models.generation.redesign.misc import ( MultiprocessingSerializer, ) From d461bf1da2f6f6be8a2da095394603134fc3f1ae Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 12 Apr 2026 01:23:01 -0700 Subject: [PATCH 22/36] update --- .../generation/redesign/sglang_generation.py | 21 +++++++------------ .../generation/redesign/sglang_worker.py | 20 ++++++++++-------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 488073fd2f..9add90358d 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -32,11 +32,8 @@ from nemo_rl.distributed.ray_actor_environment_registry 
import SGLANG_EXECUTABLE from nemo_rl.models.generation.redesign.misc import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST from nemo_rl.models.generation.redesign.ray_utils import Lock -from nemo_rl.models.generation.redesign.sglang_actors import ( - SGLANG_WORKER_FQN, - RouterActor, - SGLangWorkerInitializer, -) +from nemo_rl.models.generation.redesign.sglang_actors import RouterActor +from nemo_rl.models.generation.redesign.sglang_worker import SGLangGenerationWorker from nemo_rl.models.generation.redesign.fault_tolerance import RolloutHealthMonitor from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ -160,11 +157,6 @@ def _start_engines( reordered_bundle_indices = self.pg_reordered_bundle_indices reordered_gpu_ids = self.pg_reordered_gpu_ids - # One initializer per _start_engines() call (not per engine). - initializer = SGLangWorkerInitializer.options( - runtime_env={"py_executable": SGLANG_EXECUTABLE}, - ).remote(SGLANG_WORKER_FQN) - engine_refs: list[tuple[int, int, ray.ObjectRef]] = [] # (index, rank, ref) for i in range(len(self.all_engines)): if self.all_engines[i] is not None: @@ -220,9 +212,12 @@ def _start_engines( "num_gpus_per_engine": self.num_gpus_per_engine, } - # Collect refs — do NOT ray.get inside the loop (preserve parallel creation). - rollout_engine_ref = initializer.create.remote(actor_options, init_args, init_kwargs) - engine_refs.append((i, global_rank, rollout_engine_ref)) + # Create worker actor directly — sglang_worker.py uses lazy imports + # so it's importable in SYSTEM env; the actor runs in sglang env. + engine = SGLangGenerationWorker.options(**actor_options).remote( + *init_args, **init_kwargs + ) + engine_refs.append((i, global_rank, engine)) # Resolve all engine actor handles in parallel. 
if not engine_refs: diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 50af9fdbc5..c28fa55d88 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -8,15 +8,7 @@ import ray import requests -import sglang_router from packaging.version import parse -from sglang.srt.constants import ( - GPU_MEMORY_TYPE_CUDA_GRAPH, - GPU_MEMORY_TYPE_KV_CACHE, - GPU_MEMORY_TYPE_WEIGHTS, -) -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError from nemo_rl.models.generation.redesign.ray_utils import ( @@ -49,7 +41,7 @@ def _to_local_gpu_id(physical_gpu_id: int) -> int: f"Expected one of {visible} (physical) or 0..{len(visible)-1} (local)." ) -def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: +def launch_server_process(server_args) -> multiprocessing.Process: from sglang.srt.entrypoints.http_server import launch_server multiprocessing.set_start_method("spawn", force=True) @@ -169,6 +161,9 @@ def _format_v6_uri(addr): def _init_normal(self, server_args_dict): + import sglang_router + from sglang.srt.server_args import ServerArgs + logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") self.process = launch_server_process(ServerArgs(**server_args_dict)) @@ -281,6 +276,9 @@ def flush_cache(self): raise TimeoutError("Timeout while flushing cache.") def shutdown(self): + import sglang_router + from sglang.srt.utils import kill_process_tree + logger.info(f"Shutdown engine {self.server_host}:{self.server_port}...") if self.node_rank == 0: worker_url = f"http://{self.server_host}:{self.server_port}" @@ -340,17 +338,21 @@ def resume_memory_occupation(self, tags: list[str] = None): ) def release_memory_weights(self): + from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS return 
self.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS]) def release_memory_kv_cache_and_cuda_graph(self): + from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE return self.release_memory_occupation( tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] ) def resume_memory_weights(self): + from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS return self.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS]) def resume_memory_kv_cache_and_cuda_graph(self): + from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE return self.resume_memory_occupation( tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] ) From 752d2c63ae8528047ec8f287bb5d870590e7266b Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 12 Apr 2026 01:27:14 -0700 Subject: [PATCH 23/36] update --- .../generation/redesign/sglang_generation.py | 2 +- .../{sglang_actors.py => sglang_router.py} | 22 ------------------- 2 files changed, 1 insertion(+), 23 deletions(-) rename nemo_rl/models/generation/redesign/{sglang_actors.py => sglang_router.py} (71%) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 9add90358d..58ca55054b 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -32,7 +32,7 @@ from nemo_rl.distributed.ray_actor_environment_registry import SGLANG_EXECUTABLE from nemo_rl.models.generation.redesign.misc import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST from nemo_rl.models.generation.redesign.ray_utils import Lock -from nemo_rl.models.generation.redesign.sglang_actors import RouterActor +from nemo_rl.models.generation.redesign.sglang_router import RouterActor from nemo_rl.models.generation.redesign.sglang_worker import SGLangGenerationWorker from nemo_rl.models.generation.redesign.fault_tolerance import RolloutHealthMonitor diff --git 
a/nemo_rl/models/generation/redesign/sglang_actors.py b/nemo_rl/models/generation/redesign/sglang_router.py similarity index 71% rename from nemo_rl/models/generation/redesign/sglang_actors.py rename to nemo_rl/models/generation/redesign/sglang_router.py index 0bd841c047..03bf96aff3 100644 --- a/nemo_rl/models/generation/redesign/sglang_actors.py +++ b/nemo_rl/models/generation/redesign/sglang_router.py @@ -7,28 +7,6 @@ SGLANG_WORKER_FQN = "nemo_rl.models.generation.redesign.sglang_worker.SGLangGenerationWorker" - -@ray.remote -class SGLangWorkerInitializer: - """Loads and constructs SGLangGenerationWorker inside the sglang env. - - Mirrors the role of RayWorkerBuilder.IsolatedWorkerInitializer. - We spawn this actor with runtime_env={"py_executable": SGLANG_EXECUTABLE}, - then call importlib.import_module on the worker module from *inside* this - actor — which means the worker module's top-level sglang imports execute in - a process that actually has sglang installed. - """ - - def __init__(self, fqn: str): - self._fqn = fqn - - def create(self, actor_options: dict, init_args: tuple, init_kwargs: dict): - module_name, class_name = self._fqn.rsplit(".", 1) - module = importlib.import_module(module_name) - worker_class = getattr(module, class_name) - return worker_class.options(**actor_options).remote(*init_args, **init_kwargs) - - @ray.remote(num_cpus=1, num_gpus=0) class RouterActor: """Starts and owns the sglang router subprocess. 
From 3c6c419694d5bb8227740657ea24f972f1be66b9 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Sun, 12 Apr 2026 15:51:46 -0700 Subject: [PATCH 24/36] update --- .../models/generation/redesign/sglang_worker.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index c28fa55d88..ee07040996 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -357,6 +357,22 @@ def resume_memory_kv_cache_and_cuda_graph(self): tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] ) + def get_weights_by_name(self, name: str, truncate_size: int = 0): + """Get a model parameter by name from the SGLang server. + + Args: + name: Fully qualified parameter name (e.g. "model.layers.0.self_attn.q_proj.weight"). + truncate_size: Maximum number of elements to return per tensor. + 0 means return the full tensor. + + Returns: + Server response JSON containing the parameter values. 
+ """ + return self._make_request( + "get_weights_by_name", + {"name": name, "truncate_size": truncate_size}, + ) + def check_weights(self, action: str): return self._make_request("weights_checker", {"action": action}) From 1ab9dc491f6ed0ca829c42dbdfba82d89ce33738 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Mon, 13 Apr 2026 15:48:05 -0700 Subject: [PATCH 25/36] update --- .../generation/redesign/sglang_generation.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 58ca55054b..8a07d0c336 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -57,10 +57,11 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan self.cluster_cfg = cluster_cfg self.sglang_cfg = sglang_cfg - self.pg = cluster._init_placement_groups( + pgs = cluster._init_placement_groups( strategy="PACK", use_unified_pg=True, ) + self.pg = pgs[0] self.pg_reordered_bundle_indices, self.pg_reordered_gpu_ids = get_reordered_bundle_and_gpu_ids(self.pg) init_http_client(sglang_cfg) @@ -157,7 +158,7 @@ def _start_engines( reordered_bundle_indices = self.pg_reordered_bundle_indices reordered_gpu_ids = self.pg_reordered_gpu_ids - engine_refs: list[tuple[int, int, ray.ObjectRef]] = [] # (index, rank, ref) + local_all_engines = [] for i in range(len(self.all_engines)): if self.all_engines[i] is not None: continue @@ -217,21 +218,15 @@ def _start_engines( engine = SGLangGenerationWorker.options(**actor_options).remote( *init_args, **init_kwargs ) - engine_refs.append((i, global_rank, engine)) - - # Resolve all engine actor handles in parallel. 
- if not engine_refs: - self.num_new_engines = 0 - return [], port_cursors - resolved_engines = ray.get([ref for _, _, ref in engine_refs]) - local_all_engines = [] - for (i, global_rank, _), engine in zip(engine_refs, resolved_engines): - self.all_engines[i] = engine local_all_engines.append((global_rank, engine)) + self.all_engines[i] = engine self.num_new_engines = len(local_all_engines) + if self.num_new_engines == 0: + return [], port_cursors + base_port = max(port_cursors.values()) if port_cursors else 15000 addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( cluster_cfg=self.cluster_cfg, From cb9a6b6f44d6dbb90abfd78de3c29bc419f4a52a Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Mon, 13 Apr 2026 19:03:09 -0700 Subject: [PATCH 26/36] update --- .../generation/redesign/sglang_generation.py | 1 + .../generation/redesign/sglang_worker.py | 17 - .../models/generation/redesign/__init__.py | 0 .../models/generation/redesign/conftest.py | 79 +++ .../models/generation/redesign/helpers.py | 170 +++++++ .../models/generation/redesign/pytest.ini | 3 + .../redesign/test_sglang_generation.py | 451 ++++++++++++++++++ .../generation/redesign/test_sglang_launch.py | 99 ++++ .../generation/redesign/test_sglang_router.py | 95 ++++ .../test_sglang_shutdown_and_recover.py | 92 ++++ .../redesign/test_sglang_worker_init.py | 80 ++++ .../redesign/test_sglang_worker_memory.py | 157 ++++++ .../generation/redesign/test_utils_smoke.py | 95 ++++ 13 files changed, 1322 insertions(+), 17 deletions(-) create mode 100644 tests/unit/models/generation/redesign/__init__.py create mode 100644 tests/unit/models/generation/redesign/conftest.py create mode 100644 tests/unit/models/generation/redesign/helpers.py create mode 100644 tests/unit/models/generation/redesign/pytest.ini create mode 100644 tests/unit/models/generation/redesign/test_sglang_generation.py create mode 100644 tests/unit/models/generation/redesign/test_sglang_launch.py create mode 100644 
tests/unit/models/generation/redesign/test_sglang_router.py create mode 100644 tests/unit/models/generation/redesign/test_sglang_shutdown_and_recover.py create mode 100644 tests/unit/models/generation/redesign/test_sglang_worker_init.py create mode 100644 tests/unit/models/generation/redesign/test_sglang_worker_memory.py create mode 100644 tests/unit/models/generation/redesign/test_utils_smoke.py diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 8a07d0c336..38ded88980 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -364,6 +364,7 @@ def shutdown(self) -> bool: except Exception as e: logger.warning(f"Engine shutdown failed: {e}") ok = False + self.all_engines = [None] * len(self.all_engines) if self._router_actor is not None: try: diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index ee07040996..24664aedc5 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -357,26 +357,9 @@ def resume_memory_kv_cache_and_cuda_graph(self): tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] ) - def get_weights_by_name(self, name: str, truncate_size: int = 0): - """Get a model parameter by name from the SGLang server. - - Args: - name: Fully qualified parameter name (e.g. "model.layers.0.self_attn.q_proj.weight"). - truncate_size: Maximum number of elements to return per tensor. - 0 means return the full tensor. - - Returns: - Server response JSON containing the parameter values. 
- """ - return self._make_request( - "get_weights_by_name", - {"name": name, "truncate_size": truncate_size}, - ) - def check_weights(self, action: str): return self._make_request("weights_checker", {"action": action}) - def init_weights_update_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): return self._make_request( "init_weights_update_group", diff --git a/tests/unit/models/generation/redesign/__init__.py b/tests/unit/models/generation/redesign/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/models/generation/redesign/conftest.py b/tests/unit/models/generation/redesign/conftest.py new file mode 100644 index 0000000000..871503c751 --- /dev/null +++ b/tests/unit/models/generation/redesign/conftest.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conftest for redesign tests — real Ray, real SGLang. + +Tests in this directory exercise the redesign/ modules using real Ray actors +and real SGLang servers. The conftest stubs non-sglang heavy dependencies +but lets sglang imports resolve naturally against the installed package. +""" + +import os +import sys +from unittest.mock import MagicMock + +# Set default GPU devices before any CUDA/Ray initialisation. +# The remote cluster reserves GPUs 4-7 for this work. 
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4,5,6,7") + +# Use system Python for all Ray actors (uv not configured in container). +os.environ.setdefault("NEMO_RL_PY_EXECUTABLES_SYSTEM", "1") + +# Ensure the test directory is on sys.path so helpers.py is importable. +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +# --------------------------------------------------------------------------- +# Stub heavy modules NOT installed in the sglang test environment. +# sglang is NOT stubbed — we test against a real server. +# --------------------------------------------------------------------------- +_STUB_MODULES = [ + "decord", + "vllm", + "vllm.sampling_params", + "vllm.lora", + "vllm.lora.request", + "wandb", +] +for _mod in _STUB_MODULES: + sys.modules.setdefault(_mod, MagicMock()) + +import pytest +import ray + +from nemo_rl.models.generation.redesign.sglang_router import RouterActor + + +# --------------------------------------------------------------------------- +# Session-scoped fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="session") +def ray_cluster(): + """Initialise Ray once for the entire test session.""" + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.fixture(scope="session") +def router(ray_cluster): + """Start a real sglang router that lives for the session.""" + actor = RouterActor.remote() + ip, port = ray.get(actor.start.remote({})) + yield {"actor": actor, "ip": ip, "port": port} + try: + ray.get(actor.stop.remote()) + except Exception: + pass + ray.kill(actor) diff --git a/tests/unit/models/generation/redesign/helpers.py b/tests/unit/models/generation/redesign/helpers.py new file mode 100644 index 0000000000..0ed09d3039 --- /dev/null +++ b/tests/unit/models/generation/redesign/helpers.py @@ -0,0 +1,170 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for redesign tests. + +Kept in a regular module (not conftest.py) so test files can import it +directly. conftest.py also imports from here for fixture definitions. +""" + +import os + +import ray + +from nemo_rl.models.generation.redesign.misc import NOSET_VISIBLE_DEVICES_ENV_VARS_LIST +from nemo_rl.models.generation.redesign.ray_utils import ( + find_available_port, + get_host_info, +) +from nemo_rl.models.generation.redesign.sglang_worker import SGLangGenerationWorker + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +MODEL_PATH = "Qwen/Qwen3-0.6B" + +# Qwen3-0.6B model dimensions (verified against HuggingFace config) +HIDDEN_SIZE = 1024 +INTERMEDIATE_SIZE = 3072 +NUM_ATTENTION_HEADS = 16 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +QKV_OUTPUT_DIM = NUM_ATTENTION_HEADS * HEAD_DIM + 2 * NUM_KV_HEADS * HEAD_DIM # 4096 + + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- +def make_cluster_cfg(gpus_per_node=4): + return {"gpus_per_node": gpus_per_node} + + +def make_sglang_cfg( + model_path=MODEL_PATH, + tp_size=1, + num_gpus=4, + router_ip=None, + router_port=None, +): + return { + "sglang_cfg": { + "model_path": model_path, + "random_seed": 42, + 
"enable_memory_saver": False, + "dp_size": 1, + "pp_size": 1, + "ep_size": 1, + "skip_server_warmup": True, + "dtype": "bfloat16", + "context_length": 1024, + "log_level": "warning", + "disable_piecewise_cuda_graph": True, + }, + "sglang_server": { + "num_gpus": num_gpus, + "num_gpus_per_engine": tp_size, + "needs_offload": True, + "sglang_server_concurrency": 64, + }, + "sglang_router": { + "sglang_router_ip": router_ip, + "sglang_router_port": router_port, + }, + } + + +def make_actor_env_vars(): + """Build env-vars dict for SGLang worker actors.""" + env_vars = {name: "1" for name in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST} + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd: + env_vars["CUDA_VISIBLE_DEVICES"] = cvd + return env_vars + + +def create_worker(router_info, base_gpu_id=0, tp_size=1, rank=0): + """Create and initialise a real SGLangGenerationWorker Ray actor. + + Returns the actor handle after ``init`` completes. + """ + cluster_cfg = make_cluster_cfg() + sglang_cfg = make_sglang_cfg( + tp_size=tp_size, + router_ip=router_info["ip"], + router_port=router_info["port"], + ) + + worker = SGLangGenerationWorker.options( + num_cpus=0.2, + num_gpus=0.2, + runtime_env={"env_vars": make_actor_env_vars()}, + ).remote( + cluster_cfg, + sglang_cfg, + rank=rank, + base_gpu_id=base_gpu_id, + num_gpus_per_engine=tp_size, + ) + + host_ip = get_host_info()[1] + port = find_available_port(30000 + rank * 1000) + nccl_port = find_available_port(40000 + rank * 1000) + dist_init_port = find_available_port(50000 + rank * 1000) + + ray.get( + worker.init.remote( + dist_init_addr=f"{host_ip}:{dist_init_port}", + port=port, + nccl_port=nccl_port, + router_ip=router_info["ip"], + router_port=router_info["port"], + ) + ) + return worker + + +def make_generation_sampling_params( + max_new_tokens=16, temperature=0.0, top_p=1.0, stop=None, +): + """Build sampling_params dict for generate_one_sample / router /generate.""" + params = { + "temperature": temperature, + 
"max_new_tokens": max_new_tokens, + "top_p": top_p, + "no_stop_trim": True, + "spaces_between_special_tokens": False, + } + if stop is not None: + params["stop"] = stop + return params + + +# --------------------------------------------------------------------------- +# HTTP helpers for tests that want an explicit status-code check +# --------------------------------------------------------------------------- +def post_and_assert_200(base_url, endpoint, payload=None): + """POST ``payload`` to ``{base_url}/{endpoint}`` and assert HTTP 200. + + Tests that exercise ``release_memory_occupation`` / ``resume_memory_occupation`` + use this instead of ``_make_request`` so the 200 check is visible in the + test body (``_make_request`` consumes the status code inside + ``raise_for_status()`` and returns only the parsed JSON). + """ + import requests + + resp = requests.post(f"{base_url}/{endpoint}", json=payload or {}) + assert resp.status_code == 200, ( + f"POST {endpoint} expected 200, got {resp.status_code}: {resp.text}" + ) + return resp.json() diff --git a/tests/unit/models/generation/redesign/pytest.ini b/tests/unit/models/generation/redesign/pytest.ini new file mode 100644 index 0000000000..adf05beca2 --- /dev/null +++ b/tests/unit/models/generation/redesign/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + sglang: tests requiring sglang and GPU resources diff --git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py new file mode 100644 index 0000000000..ea44133ec7 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -0,0 +1,451 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generation tests using a real SGLangGeneration instance. + +Spins up a real RayVirtualCluster + SGLangGeneration (router + workers) +and tests ``generate()``, ``generate_async()``, and the underlying +``generate_one_sample()`` function against a live Qwen3-0.6B model. + +Model: Qwen/Qwen3-0.6B (4 GPUs, TP=1 × 4 engines) +""" + +import asyncio +import gc +from copy import deepcopy + +import pytest +import ray +import torch + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.models.generation.redesign.sglang_generation import ( + SGLangGeneration, + generate_one_sample, +) +from nemo_rl.models.generation.interfaces import GenerationDatumSpec + +from helpers import ( + MODEL_PATH, + make_generation_sampling_params, + post_and_assert_200, +) + +pytestmark = pytest.mark.sglang + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +PAD_TOKEN_ID = 151643 +EOS_TOKEN_ID = 151645 + + +# --------------------------------------------------------------------------- +# SGLang config for SGLangGeneration (mirrors existing test pattern) +# --------------------------------------------------------------------------- +def _make_sglang_generation_cfg(pad_token_id=PAD_TOKEN_ID): + return { + "backend": "sglang", + "model_name": MODEL_PATH, + "model_path": MODEL_PATH, + "tokenizer": {"name": MODEL_PATH}, + "dtype": "bfloat16", + "max_new_tokens": 16, + 
"temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "stop_token_ids": [EOS_TOKEN_ID], + "stop_strings": None, + "_pad_token_id": pad_token_id, + "sglang_cfg": { + "model_path": MODEL_PATH, + "dtype": "bfloat16", + "random_seed": 42, + "context_length": 1024, + "log_level": "warning", + "skip_server_warmup": True, + "enable_memory_saver": False, + "dp_size": 1, + "pp_size": 1, + "ep_size": 1, + "disable_piecewise_cuda_graph": True, + }, + "sglang_server": { + "num_gpus": 4, + "num_gpus_per_engine": 1, + "needs_offload": True, + "sglang_server_concurrency": 64, + }, + "sglang_router": { + "sglang_router_ip": None, + "sglang_router_port": None, + }, + "sglang_kwargs": {}, + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def tokenizer(): + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + +@pytest.fixture(scope="module") +def sglang_gen(ray_cluster, tokenizer): + """Real SGLangGeneration: RayVirtualCluster → router → 4×TP=1 engines.""" + cluster = RayVirtualCluster( + bundle_ct_per_node_list=[4], + use_gpus=True, + max_colocated_worker_groups=1, + num_gpus_per_node=4, + name="gen-test", + ) + cluster_cfg = {"gpus_per_node": 4, "num_nodes": 1} + sglang_cfg = _make_sglang_generation_cfg(pad_token_id=tokenizer.pad_token_id) + + gen = SGLangGeneration(cluster, cluster_cfg, sglang_cfg) + yield gen + try: + gen.shutdown() + except Exception: + pass + try: + cluster.shutdown() + except Exception: + pass + gc.collect() + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_input(tokenizer, prompt, pad_length=None): + """Tokenize a prompt → BatchedDataDict for generate().""" + token_ids = 
tokenizer.encode(prompt) + input_length = len(token_ids) + if pad_length and pad_length > input_length: + token_ids = token_ids + [tokenizer.pad_token_id] * (pad_length - input_length) + return BatchedDataDict( + { + "input_ids": torch.tensor([token_ids], dtype=torch.long), + "input_lengths": torch.tensor([input_length], dtype=torch.long), + } + ) + + +def _make_batch(tokenizer, prompts, pad_length=None): + """Tokenize multiple prompts → single BatchedDataDict.""" + all_ids = [] + all_lengths = [] + max_len = 0 + for p in prompts: + ids = tokenizer.encode(p) + all_ids.append(ids) + all_lengths.append(len(ids)) + max_len = max(max_len, len(ids)) + + if pad_length: + max_len = max(max_len, pad_length) + + padded = [] + for ids in all_ids: + padded.append(ids + [tokenizer.pad_token_id] * (max_len - len(ids))) + + return BatchedDataDict( + { + "input_ids": torch.tensor(padded, dtype=torch.long), + "input_lengths": torch.tensor(all_lengths, dtype=torch.long), + } + ) + + +# =================================================================== +# Tests: SGLangGeneration.generate() +# =================================================================== + + +def test_generate_returns_batched_data_dict(sglang_gen, tokenizer): + """generate() returns BatchedDataDict with all required output keys.""" + data = _make_input(tokenizer, "Hello") + result = sglang_gen.generate(data, greedy=True) + + for key in ["output_ids", "logprobs", "generation_lengths", + "unpadded_sequence_lengths", "truncated"]: + assert key in result, f"Missing key: {key}" + + +def test_generate_output_ids_shape(sglang_gen, tokenizer): + """output_ids has shape (batch_size, total_length) with correct padding.""" + data = _make_input(tokenizer, "The capital of France is") + result = sglang_gen.generate(data, greedy=True) + + assert result["output_ids"].dim() == 2 + assert result["output_ids"].shape[0] == 1 # batch_size + gen_len = result["generation_lengths"][0].item() + input_len = 
data["input_lengths"][0].item() + assert result["unpadded_sequence_lengths"][0].item() == input_len + gen_len + + +def test_generate_produces_nonzero_tokens(sglang_gen, tokenizer): + """Generated tokens are non-zero and non-pad.""" + data = _make_input(tokenizer, "What is 2 plus 2?") + result = sglang_gen.generate(data, greedy=True) + + gen_len = result["generation_lengths"][0].item() + assert gen_len > 0, "No tokens generated" + + input_len = data["input_lengths"][0].item() + generated = result["output_ids"][0, input_len:input_len + gen_len] + assert (generated != 0).all(), "Generated tokens contain zeros" + assert (generated != PAD_TOKEN_ID).all(), "Generated tokens contain pad" + + +def test_generate_greedy_determinism(sglang_gen, tokenizer): + """Same prompt + greedy=True → identical output_ids across two calls.""" + data = _make_input(tokenizer, "Once upon a time") + r1 = sglang_gen.generate(data, greedy=True) + r2 = sglang_gen.generate(data, greedy=True) + + assert torch.equal(r1["output_ids"], r2["output_ids"]), ( + "Greedy generation is not deterministic" + ) + + +def test_generate_truncation_flag(sglang_gen, tokenizer): + """When max_new_tokens is small, truncated=True.""" + # Temporarily reduce max_new_tokens + orig = sglang_gen.sglang_cfg["max_new_tokens"] + sglang_gen.sglang_cfg["max_new_tokens"] = 1 + try: + data = _make_input(tokenizer, "Tell me a very long story about dragons and") + result = sglang_gen.generate(data, greedy=True) + gen_len = result["generation_lengths"][0].item() + assert gen_len == 1, f"Expected 1 token, got {gen_len}" + assert result["truncated"][0].item() is True, "Expected truncated=True" + finally: + sglang_gen.sglang_cfg["max_new_tokens"] = orig + + +def test_generate_logprobs_valid(sglang_gen, tokenizer): + """Logprobs are finite, non-positive at generated positions.""" + data = _make_input(tokenizer, "Hello world") + result = sglang_gen.generate(data, greedy=True) + + gen_len = result["generation_lengths"][0].item() + 
input_len = data["input_lengths"][0].item() + lps = result["logprobs"][0, input_len:input_len + gen_len] + + assert torch.isfinite(lps).all(), "Logprobs contain NaN or Inf" + assert (lps <= 0.0).all(), "Logprobs should be non-positive" + + +def test_generate_respects_max_new_tokens(sglang_gen, tokenizer): + """generation_lengths ≤ max_new_tokens for all samples.""" + data = _make_input(tokenizer, "Count from 1 to 100:") + result = sglang_gen.generate(data, greedy=True) + + max_new = sglang_gen.sglang_cfg["max_new_tokens"] + gen_len = result["generation_lengths"][0].item() + assert gen_len <= max_new, f"gen_len={gen_len} > max_new_tokens={max_new}" + + +def test_generate_batch_multiple_samples(sglang_gen, tokenizer): + """Batch of 3 prompts: all produce valid output.""" + prompts = [ + "Hello, my name is", + "The capital of France is", + "What is 2 plus 2?", + ] + data = _make_batch(tokenizer, prompts) + result = sglang_gen.generate(data, greedy=True) + + assert result["output_ids"].shape[0] == 3 + assert result["generation_lengths"].shape[0] == 3 + for i in range(3): + gen_len = result["generation_lengths"][i].item() + assert gen_len > 0, f"Sample {i} generated 0 tokens" + + +def test_generate_empty_input(sglang_gen): + """Empty batch → empty BatchedDataDict with zero-size tensors.""" + data = BatchedDataDict( + { + "input_ids": torch.zeros((0, 0), dtype=torch.long), + "input_lengths": torch.zeros(0, dtype=torch.long), + } + ) + result = sglang_gen.generate(data, greedy=True) + assert result["output_ids"].shape[0] == 0 + + +def test_generate_with_stop_strings(sglang_gen, tokenizer): + """Stop string causes early termination.""" + orig_stop = sglang_gen.sglang_cfg.get("stop_strings") + sglang_gen.sglang_cfg["stop_strings"] = ["\n"] + try: + data = _make_input(tokenizer, "List:\n1. 
Apple\n2.") + result = sglang_gen.generate(data, greedy=True) + gen_len = result["generation_lengths"][0].item() + max_new = sglang_gen.sglang_cfg["max_new_tokens"] + # If stop string triggered, generation should be shorter than max + # (this is a soft check — the model might produce \n on first token) + assert gen_len <= max_new + finally: + sglang_gen.sglang_cfg["stop_strings"] = orig_stop + + +# =================================================================== +# Tests: SGLangGeneration.generate_async() +# =================================================================== + + +def test_generate_async_yields_single_sample(sglang_gen, tokenizer): + """generate_async() with batch_size=1 yields (0, BatchedDataDict).""" + data = _make_input(tokenizer, "Hello") + + async def _run(): + results = [] + async for idx, batch in sglang_gen.generate_async(data, greedy=True): + results.append((idx, batch)) + return results + + results = asyncio.run(_run()) + assert len(results) == 1 + idx, batch = results[0] + assert idx == 0 + assert "output_ids" in batch + assert batch["generation_lengths"][0].item() > 0 + + +def test_generate_async_output_matches_generate(sglang_gen, tokenizer): + """Same prompt, greedy: generate() and generate_async() produce same tokens.""" + data = _make_input(tokenizer, "The answer is") + sync_result = sglang_gen.generate(data, greedy=True) + + async def _run(): + results = [] + async for _, batch in sglang_gen.generate_async(data, greedy=True): + results.append(batch) + return results[0] + + async_result = asyncio.run(_run()) + + sync_len = sync_result["generation_lengths"][0].item() + async_len = async_result["generation_lengths"][0].item() + assert sync_len == async_len, f"sync={sync_len} vs async={async_len}" + + input_len = data["input_lengths"][0].item() + sync_tokens = sync_result["output_ids"][0, input_len:input_len + sync_len] + async_tokens = async_result["output_ids"][0, input_len:input_len + async_len] + assert torch.equal(sync_tokens, 
async_tokens), ( + "generate() and generate_async() produced different tokens" + ) + + +# =================================================================== +# Tests: generate_one_sample() — the underlying async function +# =================================================================== + + +def test_generate_one_sample_returns_correct_tuple(sglang_gen, tokenizer): + """generate_one_sample() returns (index, tokens, logprobs, truncated).""" + sp = make_generation_sampling_params(max_new_tokens=5, temperature=0.0) + input_ids = tokenizer.encode("The capital of France is") + + result = asyncio.run( + generate_one_sample( + sglang_gen.router_ip, sglang_gen.router_port, sp, input_ids, index=42 + ) + ) + + assert len(result) == 4 + idx, tokens, logprobs, truncated = result + assert idx == 42 + assert isinstance(tokens, list) and len(tokens) > 0 + assert isinstance(logprobs, list) and len(logprobs) == len(tokens) + assert isinstance(truncated, bool) + assert all(isinstance(t, int) for t in tokens) + assert all(isinstance(lp, float) for lp in logprobs) + + +def test_generate_after_memory_cycle(sglang_gen, tokenizer): + """Generate → offload/onload → generate → same greedy output.""" + data = _make_input(tokenizer, "Two plus two equals") + r_before = sglang_gen.generate(data, greedy=True) + + # Offload and onload weights + KV on all engines + for engine in sglang_gen.engines: + ray.get(engine.release_memory_weights.remote()) + ray.get(engine.release_memory_kv_cache_and_cuda_graph.remote()) + ray.get(engine.resume_memory_weights.remote()) + ray.get(engine.resume_memory_kv_cache_and_cuda_graph.remote()) + + r_after = sglang_gen.generate(data, greedy=True) + + assert torch.equal(r_before["output_ids"], r_after["output_ids"]), ( + "Generation output changed after memory cycle" + ) + + +def test_generate_after_memory_cycle_via_http_200(sglang_gen, tokenizer): + """Generate → offload/onload via direct HTTP (asserting 200) → generate → same greedy output. 
+ + Equivalent to ``test_generate_after_memory_cycle`` but drives the cycle by + POSTing directly to each engine's ``release_memory_occupation`` / + ``resume_memory_occupation`` endpoint and asserting + ``resp.status_code == 200`` on every call (via ``post_and_assert_200``). + This avoids ``_make_request``, which hides the status code behind + ``raise_for_status()``. + """ + data = _make_input(tokenizer, "Two plus two equals") + r_before = sglang_gen.generate(data, greedy=True) + + for engine in sglang_gen.engines: + base_url = ray.get(engine.get_base_url.remote()) + assert base_url is not None + + # Release weights (flush_cache first, mirroring release_memory_occupation) + ray.get(engine.flush_cache.remote()) + post_and_assert_200( + base_url, "release_memory_occupation", {"tags": ["weights"]} + ) + # Release KV cache + CUDA graphs + ray.get(engine.flush_cache.remote()) + post_and_assert_200( + base_url, + "release_memory_occupation", + {"tags": ["kv_cache", "cuda_graph"]}, + ) + # Resume weights + post_and_assert_200( + base_url, "resume_memory_occupation", {"tags": ["weights"]} + ) + # Resume KV cache + CUDA graphs + post_and_assert_200( + base_url, + "resume_memory_occupation", + {"tags": ["kv_cache", "cuda_graph"]}, + ) + + r_after = sglang_gen.generate(data, greedy=True) + + assert torch.equal(r_before["output_ids"], r_after["output_ids"]), ( + "Generation output changed after HTTP-driven memory cycle" + ) diff --git a/tests/unit/models/generation/redesign/test_sglang_launch.py b/tests/unit/models/generation/redesign/test_sglang_launch.py new file mode 100644 index 0000000000..cf43ce4da6 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_sglang_launch.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the SGLangGeneration init chain — multi-worker orchestration, +router integration, and the Lock actor. + +Instead of instantiating the full SGLangGeneration class (which needs +RayVirtualCluster), we test each component that __init__ wires together: + • multiple SGLangGenerationWorker actors started in parallel + • all workers registered with the real router + • the Lock Ray actor used for rollout_engine_lock +""" + +import pytest +import ray +import requests + +from nemo_rl.models.generation.redesign.ray_utils import Lock + +from helpers import create_worker + +pytestmark = pytest.mark.sglang + + +# ------------------------------------------------------------------ +# Multi-worker orchestration +# ------------------------------------------------------------------ +@pytest.fixture(scope="module") +def two_workers(ray_cluster, router): + """Start two TP=1 workers on GPUs 0 and 1.""" + w0 = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + w1 = create_worker(router, base_gpu_id=1, tp_size=1, rank=1) + yield [w0, w1] + for w in [w0, w1]: + try: + ray.get(w.shutdown.remote()) + except Exception: + pass + + +def test_multiple_workers_init(two_workers): + """Two workers start successfully on separate GPUs.""" + for w in two_workers: + assert ray.get(w.health_generate.remote()) is True + + +def test_workers_register_with_router(two_workers, router): + """Both workers appear in the router's /workers list.""" + resp = requests.get( + f"http://{router['ip']}:{router['port']}/workers", timeout=10 + ) + assert resp.status_code == 200 + 
workers_list = resp.json().get("workers", []) + assert len(workers_list) >= 2 + + +def test_workers_have_distinct_urls(two_workers): + """Each worker reports a unique base URL.""" + urls = [ray.get(w.get_base_url.remote()) for w in two_workers] + assert len(set(urls)) == 2 + for url in urls: + assert url.startswith("http://") + + +# ------------------------------------------------------------------ +# Lock actor +# ------------------------------------------------------------------ +def test_lock_actor_acquire_release(ray_cluster): + """Lock.acquire / release round-trip works.""" + lock = Lock.options(num_cpus=0.1, num_gpus=0).remote() + try: + assert ray.get(lock.acquire.remote()) is True + ray.get(lock.release.remote()) + finally: + ray.kill(lock) + + +def test_lock_actor_mutual_exclusion(ray_cluster): + """A second acquire fails while the lock is held.""" + lock = Lock.options(num_cpus=0.1, num_gpus=0).remote() + try: + assert ray.get(lock.acquire.remote()) is True + assert ray.get(lock.acquire.remote()) is False # already held + ray.get(lock.release.remote()) + assert ray.get(lock.acquire.remote()) is True # free again + ray.get(lock.release.remote()) + finally: + ray.kill(lock) diff --git a/tests/unit/models/generation/redesign/test_sglang_router.py b/tests/unit/models/generation/redesign/test_sglang_router.py new file mode 100644 index 0000000000..354e861688 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_sglang_router.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RouterActor lifecycle — start, port allocation, stop. + +All tests use a real Ray cluster and a real sglang_router subprocess. +Each test creates its own RouterActor to avoid cross-test interference. +""" + +import pytest +import ray +import requests + +from nemo_rl.models.generation.redesign.ray_utils import find_available_port +from nemo_rl.models.generation.redesign.sglang_router import RouterActor + +pytestmark = pytest.mark.sglang + + +def _start_and_cleanup(actor, router_cfg): + """Start a router, return (ip, port), register cleanup on failure.""" + ip, port = ray.get(actor.start.remote(router_cfg)) + return ip, port + + +def _stop_router(actor): + try: + ray.get(actor.stop.remote()) + except Exception: + pass + ray.kill(actor) + + +def test_start_returns_ip_and_port(ray_cluster): + """RouterActor.start returns a (str, int) tuple.""" + actor = RouterActor.remote() + try: + ip, port = _start_and_cleanup(actor, {}) + assert isinstance(ip, str) and len(ip) > 0 + assert isinstance(port, int) and port > 0 + finally: + _stop_router(actor) + + +def test_start_uses_configured_port(ray_cluster): + """When sglang_router_port is set, the router uses that exact port.""" + configured_port = find_available_port(9000) + actor = RouterActor.remote() + try: + ip, port = _start_and_cleanup(actor, {"sglang_router_port": configured_port}) + assert port == configured_port + finally: + _stop_router(actor) + + +def test_start_finds_port_when_not_configured(ray_cluster): + """When sglang_router_port is None, the router picks one automatically.""" + actor = RouterActor.remote() + try: + ip, port = _start_and_cleanup(actor, {}) + assert isinstance(port, int) and port > 0 + finally: + _stop_router(actor) + + +def test_stop_terminates_process(ray_cluster): + """stop() completes without error after a successful start.""" + actor = RouterActor.remote() + 
ray.get(actor.start.remote({})) + ray.get(actor.stop.remote()) # should not raise + ray.kill(actor) + + +def test_router_serves_workers_endpoint(ray_cluster): + """A started router exposes the /workers HTTP endpoint.""" + actor = RouterActor.remote() + try: + ip, port = _start_and_cleanup(actor, {}) + resp = requests.get(f"http://{ip}:{port}/workers", timeout=10) + assert resp.status_code == 200 + data = resp.json() + assert "workers" in data + finally: + _stop_router(actor) diff --git a/tests/unit/models/generation/redesign/test_sglang_shutdown_and_recover.py b/tests/unit/models/generation/redesign/test_sglang_shutdown_and_recover.py new file mode 100644 index 0000000000..4073ffc1a5 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_sglang_shutdown_and_recover.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for worker shutdown, router un-registration, and recovery. + +Each test creates a **fresh** worker so that shutdown / crash is +non-destructive to the rest of the session. 
+""" + +import time + +import pytest +import ray +import requests + +from helpers import create_worker + +pytestmark = pytest.mark.sglang + + +def _get_worker_count(router): + """Get the number of workers registered with the router.""" + resp = requests.get( + f"http://{router['ip']}:{router['port']}/workers", timeout=10 + ) + return len(resp.json().get("workers", [])) + + +def _wait_for_worker_count(router, expected, timeout=15): + """Poll until the router reports the expected worker count.""" + deadline = time.time() + timeout + while time.time() < deadline: + if _get_worker_count(router) == expected: + return True + time.sleep(1) + return False + + +# ------------------------------------------------------------------ +# shutdown +# ------------------------------------------------------------------ +def test_shutdown_worker(ray_cluster, router): + """Worker shutdown completes without error.""" + worker = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + ray.get(worker.shutdown.remote()) # should not raise + + +def test_shutdown_unregisters_from_router(ray_cluster, router): + """After shutdown the worker is no longer in the router's list.""" + count_before_create = _get_worker_count(router) + worker = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + + # Wait for the worker to appear in the router + assert _wait_for_worker_count(router, count_before_create + 1), ( + f"Worker never appeared in router (expected {count_before_create + 1}, " + f"got {_get_worker_count(router)})" + ) + + ray.get(worker.shutdown.remote()) + + # Wait for the worker to disappear + assert _wait_for_worker_count(router, count_before_create), ( + f"Worker still in router after shutdown (expected {count_before_create}, " + f"got {_get_worker_count(router)})" + ) + + +def test_new_worker_after_shutdown(ray_cluster, router): + """A new worker can be created on the same GPU after shutdown.""" + w1 = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + 
ray.get(w1.shutdown.remote()) + + w2 = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + assert ray.get(w2.health_generate.remote()) is True + ray.get(w2.shutdown.remote()) + + +def test_simulate_crash(ray_cluster, router): + """simulate_crash (which calls shutdown) does not raise.""" + worker = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + ray.get(worker.simulate_crash.remote()) diff --git a/tests/unit/models/generation/redesign/test_sglang_worker_init.py b/tests/unit/models/generation/redesign/test_sglang_worker_init.py new file mode 100644 index 0000000000..867f26e732 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_sglang_worker_init.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SGLangGenerationWorker.init — server launch and router registration. + +Uses a real Ray cluster, a real sglang router, and a real SGLang server +(Qwen3-0.6B, TP=1). A module-scoped worker is shared across all tests +in this file. 
+""" + +import pytest +import ray +import requests + +from helpers import create_worker + +pytestmark = pytest.mark.sglang + + +@pytest.fixture(scope="module") +def worker(ray_cluster, router): + """Create a single TP=1 worker for this module's tests.""" + w = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + yield w + try: + ray.get(w.shutdown.remote()) + except Exception: + pass + + +# ------------------------------------------------------------------ +def test_init_server_healthy(worker): + """After init, the underlying SGLang server is healthy.""" + result = ray.get(worker.health_generate.remote()) + assert result is True + + +def test_init_sets_base_url(worker): + """get_base_url returns a valid http:// URL after init.""" + url = ray.get(worker.get_base_url.remote()) + assert url is not None + assert url.startswith("http://") + + +def test_init_registers_with_router(worker, router): + """The worker registers itself with the session router on init.""" + resp = requests.get( + f"http://{router['ip']}:{router['port']}/workers", timeout=10 + ) + assert resp.status_code == 200 + workers_list = resp.json().get("workers", []) + # At least one worker should be registered + assert len(workers_list) >= 1 + + +def test_health_generate_returns_true(worker): + """health_generate succeeds multiple times (idempotent).""" + for _ in range(3): + assert ray.get(worker.health_generate.remote()) is True + + +def test_get_gpu_uuids(worker): + """get_gpu_uuids returns a non-empty list of GPU-* strings.""" + uuids = ray.get(worker.get_gpu_uuids.remote()) + assert isinstance(uuids, list) + assert len(uuids) >= 1 + for uuid in uuids: + assert isinstance(uuid, str) + assert uuid.startswith("GPU-") diff --git a/tests/unit/models/generation/redesign/test_sglang_worker_memory.py b/tests/unit/models/generation/redesign/test_sglang_worker_memory.py new file mode 100644 index 0000000000..b777c97feb --- /dev/null +++ b/tests/unit/models/generation/redesign/test_sglang_worker_memory.py 
@@ -0,0 +1,157 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for SGLangGenerationWorker memory management: +flush_cache, release_memory_occupation, resume_memory_occupation. + +Uses a real SGLang server (Qwen3-0.6B, TP=1). Each test is fully +self-contained — it leaves the server in the same state it found it. +""" + +import pytest +import ray + +from helpers import create_worker, post_and_assert_200 + +pytestmark = pytest.mark.sglang + + +@pytest.fixture(scope="module") +def worker(ray_cluster, router): + """A TP=1 worker dedicated to memory tests.""" + w = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) + yield w + try: + ray.get(w.shutdown.remote()) + except Exception: + pass + + +# ------------------------------------------------------------------ +# flush_cache +# ------------------------------------------------------------------ +def test_flush_cache_success(worker): + """flush_cache returns without error on a healthy server.""" + ray.get(worker.flush_cache.remote()) + + +# ------------------------------------------------------------------ +# release / resume — weights (self-contained) +# ------------------------------------------------------------------ +def test_release_and_resume_memory_weights(worker): + """release_memory_weights followed by resume succeeds.""" + ray.get(worker.release_memory_weights.remote()) + ray.get(worker.resume_memory_weights.remote()) + + +# 
------------------------------------------------------------------ +# release / resume — KV cache + CUDA graphs (self-contained) +# ------------------------------------------------------------------ +def test_release_and_resume_memory_kv_cache_and_cuda_graph(worker): + """release then resume KV cache + CUDA graphs succeeds.""" + ray.get(worker.release_memory_kv_cache_and_cuda_graph.remote()) + ray.get(worker.resume_memory_kv_cache_and_cuda_graph.remote()) + + +# ------------------------------------------------------------------ +# full offload / onload cycle +# ------------------------------------------------------------------ +def test_full_offload_onload_cycle(worker): + """Full offload (weights then KV) then onload (weights then KV) works.""" + ray.get(worker.release_memory_weights.remote()) + ray.get(worker.release_memory_kv_cache_and_cuda_graph.remote()) + ray.get(worker.resume_memory_weights.remote()) + ray.get(worker.resume_memory_kv_cache_and_cuda_graph.remote()) + + +def test_health_after_memory_cycle(worker): + """health_generate passes after a full offload / onload cycle.""" + ray.get(worker.release_memory_weights.remote()) + ray.get(worker.release_memory_kv_cache_and_cuda_graph.remote()) + ray.get(worker.resume_memory_weights.remote()) + ray.get(worker.resume_memory_kv_cache_and_cuda_graph.remote()) + assert ray.get(worker.health_generate.remote()) is True + + +def test_flush_cache_after_resume(worker): + """flush_cache succeeds after a release → resume round-trip.""" + ray.get(worker.release_memory_weights.remote()) + ray.get(worker.resume_memory_weights.remote()) + ray.get(worker.flush_cache.remote()) + + +# ------------------------------------------------------------------ +# Equivalent tests using _make_request directly — verify HTTP 200 +# ------------------------------------------------------------------ +def test_offload_onload_via_http_200(worker): + """Full offload/onload cycle driven by direct HTTP POST, asserting 200. 
+ + Uses ``post_and_assert_200`` (which checks ``resp.status_code == 200`` + explicitly) rather than ``_make_request`` — ``_make_request`` throws the + status code away inside ``raise_for_status()`` so callers cannot inspect it. + """ + base_url = ray.get(worker.get_base_url.remote()) + assert base_url is not None + + # Release weights (flush_cache first, mirroring release_memory_occupation) + ray.get(worker.flush_cache.remote()) + post_and_assert_200( + base_url, "release_memory_occupation", {"tags": ["weights"]} + ) + # Release KV cache + CUDA graphs + ray.get(worker.flush_cache.remote()) + post_and_assert_200( + base_url, "release_memory_occupation", {"tags": ["kv_cache", "cuda_graph"]} + ) + # Resume weights + post_and_assert_200( + base_url, "resume_memory_occupation", {"tags": ["weights"]} + ) + # Resume KV cache + CUDA graphs + post_and_assert_200( + base_url, "resume_memory_occupation", {"tags": ["kv_cache", "cuda_graph"]} + ) + assert ray.get(worker.health_generate.remote()) is True + + +def test_double_offload_onload_cycle(worker): + """Two back-to-back full offload/onload cycles via direct HTTP, asserting 200 on every call. + + Exercises the same endpoints twice to catch state leaks across cycles. + Uses ``post_and_assert_200`` so each of the eight POSTs explicitly + verifies ``resp.status_code == 200``. 
+ """ + base_url = ray.get(worker.get_base_url.remote()) + assert base_url is not None + + for _ in range(2): + ray.get(worker.flush_cache.remote()) + post_and_assert_200( + base_url, "release_memory_occupation", {"tags": ["weights"]} + ) + ray.get(worker.flush_cache.remote()) + post_and_assert_200( + base_url, + "release_memory_occupation", + {"tags": ["kv_cache", "cuda_graph"]}, + ) + post_and_assert_200( + base_url, "resume_memory_occupation", {"tags": ["weights"]} + ) + post_and_assert_200( + base_url, + "resume_memory_occupation", + {"tags": ["kv_cache", "cuda_graph"]}, + ) + assert ray.get(worker.health_generate.remote()) is True diff --git a/tests/unit/models/generation/redesign/test_utils_smoke.py b/tests/unit/models/generation/redesign/test_utils_smoke.py new file mode 100644 index 0000000000..cedeecdb35 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_utils_smoke.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Smoke tests for utility modules (ray_utils, misc, async_utils). + +These tests verify basic functionality of helper utilities and do NOT +require a running SGLang server or GPU. 
+""" + +import asyncio +import multiprocessing + +from nemo_rl.models.generation.redesign.async_utils import AsyncLoopThread +from nemo_rl.models.generation.redesign.misc import ( + MultiprocessingSerializer, + terminate_process, +) +from nemo_rl.models.generation.redesign.ray_utils import ( + _wrap_ipv6, + find_available_port, + get_host_info, + is_port_available, +) + + +# --------------------------------------------------------------------------- +# ray_utils +# --------------------------------------------------------------------------- +def test_find_available_port(): + """find_available_port returns a port that passes is_port_available.""" + port = find_available_port(20000) + assert isinstance(port, int) + assert port > 0 + assert is_port_available(port) + + +def test_wrap_ipv6_noop_for_ipv4(): + """IPv4 addresses are returned unchanged by _wrap_ipv6.""" + assert _wrap_ipv6("192.168.1.1") == "192.168.1.1" + assert _wrap_ipv6("10.0.0.1") == "10.0.0.1" + assert _wrap_ipv6("127.0.0.1") == "127.0.0.1" + + +def test_get_host_info_returns_tuple(): + """get_host_info returns (hostname, ip_address) strings.""" + hostname, ip = get_host_info() + assert isinstance(hostname, str) and len(hostname) > 0 + assert isinstance(ip, str) and len(ip) > 0 + + +# --------------------------------------------------------------------------- +# misc +# --------------------------------------------------------------------------- +def test_serializer_roundtrip(): + """serialize → deserialize returns the original object.""" + obj = {"key": "value", "numbers": [1, 2, 3], "nested": {"a": True}} + serialized = MultiprocessingSerializer.serialize(obj, output_str=True) + assert isinstance(serialized, str) and len(serialized) > 0 + deserialized = MultiprocessingSerializer.deserialize(serialized) + assert deserialized == obj + + +def test_terminate_process_already_dead(): + """terminate_process does not raise when the process is already dead.""" + p = multiprocessing.Process(target=lambda: None) 
+ p.start() + p.join() + # Process has already exited — should be a harmless no-op + terminate_process(p) + + +# --------------------------------------------------------------------------- +# async_utils +# --------------------------------------------------------------------------- +def test_async_loop_thread_runs_coroutine(): + """AsyncLoopThread can submit and await a coroutine.""" + loop = AsyncLoopThread() + + async def coro(): + await asyncio.sleep(0.01) + return 42 + + result = loop.run(coro()) + assert result == 42 From 1d90b3f0cd4b08f83f15321e31578a78f5a2e685 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Mon, 13 Apr 2026 20:03:54 -0700 Subject: [PATCH 27/36] update --- .../redesign/test_sglang_generation.py | 33 ++++++++--- .../redesign/test_sglang_worker_memory.py | 58 +++++++++++++++---- 2 files changed, 72 insertions(+), 19 deletions(-) diff --git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py index ea44133ec7..26930c1d72 100644 --- a/tests/unit/models/generation/redesign/test_sglang_generation.py +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -18,7 +18,11 @@ and tests ``generate()``, ``generate_async()``, and the underlying ``generate_one_sample()`` function against a live Qwen3-0.6B model. 
-Model: Qwen/Qwen3-0.6B (4 GPUs, TP=1 × 4 engines) +Parametrised over two configurations (both use 4 GPUs total): + • tp4_1server — 1 server × TP=4 + • tp2_2servers — 2 servers × TP=2 + +Model: Qwen/Qwen3-0.6B """ import asyncio @@ -55,7 +59,7 @@ # --------------------------------------------------------------------------- # SGLang config for SGLangGeneration (mirrors existing test pattern) # --------------------------------------------------------------------------- -def _make_sglang_generation_cfg(pad_token_id=PAD_TOKEN_ID): +def _make_sglang_generation_cfg(pad_token_id=PAD_TOKEN_ID, tp_size=1): return { "backend": "sglang", "model_name": MODEL_PATH, @@ -84,7 +88,7 @@ def _make_sglang_generation_cfg(pad_token_id=PAD_TOKEN_ID): }, "sglang_server": { "num_gpus": 4, - "num_gpus_per_engine": 1, + "num_gpus_per_engine": tp_size, "needs_offload": True, "sglang_server_concurrency": 64, }, @@ -106,18 +110,31 @@ def tokenizer(): return AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) -@pytest.fixture(scope="module") -def sglang_gen(ray_cluster, tokenizer): - """Real SGLangGeneration: RayVirtualCluster → router → 4×TP=1 engines.""" +@pytest.fixture( + scope="module", + params=[ + pytest.param({"tp_size": 4, "num_servers": 1}, id="tp4_1server"), + pytest.param({"tp_size": 2, "num_servers": 2}, id="tp2_2servers"), + ], +) +def sglang_gen(request, ray_cluster, tokenizer): + """Real SGLangGeneration: RayVirtualCluster → router → engines. + + Parametrised over tp4_1server (1 server × TP=4) and tp2_2servers + (2 servers × TP=2). All variants use 4 GPUs. 
+ """ + tp_size = request.param["tp_size"] cluster = RayVirtualCluster( bundle_ct_per_node_list=[4], use_gpus=True, max_colocated_worker_groups=1, num_gpus_per_node=4, - name="gen-test", + name=f"gen-test-{request.param['num_servers']}srv-tp{tp_size}", ) cluster_cfg = {"gpus_per_node": 4, "num_nodes": 1} - sglang_cfg = _make_sglang_generation_cfg(pad_token_id=tokenizer.pad_token_id) + sglang_cfg = _make_sglang_generation_cfg( + pad_token_id=tokenizer.pad_token_id, tp_size=tp_size, + ) gen = SGLangGeneration(cluster, cluster_cfg, sglang_cfg) yield gen diff --git a/tests/unit/models/generation/redesign/test_sglang_worker_memory.py b/tests/unit/models/generation/redesign/test_sglang_worker_memory.py index b777c97feb..f99b773367 100644 --- a/tests/unit/models/generation/redesign/test_sglang_worker_memory.py +++ b/tests/unit/models/generation/redesign/test_sglang_worker_memory.py @@ -15,8 +15,16 @@ """Tests for SGLangGenerationWorker memory management: flush_cache, release_memory_occupation, resume_memory_occupation. -Uses a real SGLang server (Qwen3-0.6B, TP=1). Each test is fully -self-contained — it leaves the server in the same state it found it. +Uses a real SGLang server (Qwen3-0.6B), parametrised over two +configurations so the same tests exercise both a single-rank TP=1 +worker and a TP=2 worker: + + • tp1 — 1 worker × TP=1 + • tp2_2workers — 2 workers × TP=2 (the memory tests target worker 0, + but both workers share the router) + +Each test is fully self-contained — it leaves the server in the same +state it found it. 
""" import pytest @@ -27,15 +35,43 @@ pytestmark = pytest.mark.sglang -@pytest.fixture(scope="module") -def worker(ray_cluster, router): - """A TP=1 worker dedicated to memory tests.""" - w = create_worker(router, base_gpu_id=0, tp_size=1, rank=0) - yield w - try: - ray.get(w.shutdown.remote()) - except Exception: - pass +@pytest.fixture( + scope="module", + params=[ + pytest.param({"tp_size": 1, "num_workers": 1}, id="tp1"), + pytest.param({"tp_size": 2, "num_workers": 2}, id="tp2_2workers"), + ], +) +def worker(request, ray_cluster, router): + """Worker(s) dedicated to memory tests. + + For ``tp1`` a single TP=1 worker is created. For ``tp2_2workers`` + two TP=2 workers share the same router (mirroring the 2-servers + configuration exercised elsewhere); memory tests run against the + first worker but the second is kept alive so the router has the + multi-worker topology in place. + """ + tp_size = request.param["tp_size"] + num_workers = request.param["num_workers"] + + workers = [] + for rank in range(num_workers): + workers.append( + create_worker( + router, + base_gpu_id=rank * tp_size, + tp_size=tp_size, + rank=rank, + ) + ) + + yield workers[0] + + for w in workers: + try: + ray.get(w.shutdown.remote()) + except Exception: + pass # ------------------------------------------------------------------ From 4c6d66fd7f2cc9bf61c9e58cefd9215379034f0b Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 00:51:51 -0700 Subject: [PATCH 28/36] update --- .../redesign/test_sglang_generation.py | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py index 26930c1d72..3a5a8a4116 100644 --- a/tests/unit/models/generation/redesign/test_sglang_generation.py +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -423,13 +423,6 @@ def test_generate_after_memory_cycle(sglang_gen, tokenizer): 
def test_generate_after_memory_cycle_via_http_200(sglang_gen, tokenizer): """Generate → offload/onload via direct HTTP (asserting 200) → generate → same greedy output. - - Equivalent to ``test_generate_after_memory_cycle`` but drives the cycle by - POSTing directly to each engine's ``release_memory_occupation`` / - ``resume_memory_occupation`` endpoint and asserting - ``resp.status_code == 200`` on every call (via ``post_and_assert_200``). - This avoids ``_make_request``, which hides the status code behind - ``raise_for_status()``. """ data = _make_input(tokenizer, "Two plus two equals") r_before = sglang_gen.generate(data, greedy=True) @@ -466,3 +459,38 @@ def test_generate_after_memory_cycle_via_http_200(sglang_gen, tokenizer): assert torch.equal(r_before["output_ids"], r_after["output_ids"]), ( "Generation output changed after HTTP-driven memory cycle" ) + + +def test_generate_after_memory_cycle_top_level_api(sglang_gen, tokenizer): + """Generate -> top-level offload/onload -> generate -> same greedy output. + """ + data = _make_input(tokenizer, "Two plus two equals") + + r_before = sglang_gen.generate(data, greedy=True) + input_len = data["input_lengths"][0].item() + gen_len_before = r_before["generation_lengths"][0].item() + assert gen_len_before > 0, "generate() before memory cycle produced 0 tokens" + tokens_before = r_before["output_ids"][0, input_len : input_len + gen_len_before] + assert (tokens_before != 0).all(), "before: generated tokens contain zeros" + assert (tokens_before != PAD_TOKEN_ID).all(), "before: generated tokens contain pad" + + # Full offload + onload cycle using the top-level SGLangGeneration API. 
+ sglang_gen.offload_weights() + sglang_gen.offload_kv() + sglang_gen.onload_weights() + sglang_gen.onload_kv() + + r_after = sglang_gen.generate(data, greedy=True) + gen_len_after = r_after["generation_lengths"][0].item() + assert gen_len_after > 0, "generate() after memory cycle produced 0 tokens" + tokens_after = r_after["output_ids"][0, input_len : input_len + gen_len_after] + assert (tokens_after != 0).all(), "after: generated tokens contain zeros" + assert (tokens_after != PAD_TOKEN_ID).all(), "after: generated tokens contain pad" + + assert gen_len_before == gen_len_after, ( + f"Different generation_lengths before vs. after: " + f"before={gen_len_before}, after={gen_len_after}" + ) + assert torch.equal(r_before["output_ids"], r_after["output_ids"]), ( + "Generation output changed after top-level offload/onload cycle" + ) From 012c487b8a81b904ce8e2b51598d8c373e4dc7b6 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 11:09:58 -0700 Subject: [PATCH 29/36] update --- nemo_rl/algorithms/grpo.py | 12 +- nemo_rl/models/generation/redesign/config.py | 5 +- nemo_rl/models/generation/redesign/misc.py | 52 -- .../generation/redesign/sglang_worker.py | 3 +- .../generation/sglang/sglang_copied_utils.py | 106 +--- nemo_rl/models/policy/interfaces.py | 9 +- nemo_rl/models/policy/lm_policy.py | 13 +- nemo_rl/models/policy/redesign_utils.py | 272 ++++++++++ nemo_rl/models/policy/utils.py | 399 +++++++-------- .../workers/dtensor_policy_worker_v2.py | 44 +- .../models/generation/redesign/helpers.py | 1 - .../redesign/test_sglang_generation.py | 2 +- .../generation/redesign/test_utils_smoke.py | 4 +- .../redesign/test_weight_update_real.py | 473 ++++++++++++++++++ 14 files changed, 965 insertions(+), 430 deletions(-) create mode 100644 nemo_rl/models/policy/redesign_utils.py create mode 100644 tests/unit/models/generation/redesign/test_weight_update_real.py diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 447f4f4ebb..9ca0398db4 100644 --- 
a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1152,15 +1152,11 @@ def refit_policy_generation( ) if isinstance(policy_generation, SGLangGeneration): - sglang_url_to_gpu_uuids = ( - policy_generation.get_sglang_url_to_gpu_uuids() - ) - # Stream weights via HTTP - flush_success = policy_generation.invalidate_kv_cache() - if not flush_success: - print("SGLang KV cache invalidation failed before weight update. ") + # Stream weights to colocated SGLang engines via CUDA IPC over HTTP. + # Engine-i owns global ranks [i*K, (i+1)*K) where K = num_gpus_per_engine. futures_train = policy.stream_weights_via_http( - sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + rollout_engines=policy_generation.rollout_engines, + num_gpus_per_engine=policy_generation.num_gpus_per_engine, ) # Wait for all workers to complete ray.get(futures_train) diff --git a/nemo_rl/models/generation/redesign/config.py b/nemo_rl/models/generation/redesign/config.py index e74bce82b0..357462c019 100644 --- a/nemo_rl/models/generation/redesign/config.py +++ b/nemo_rl/models/generation/redesign/config.py @@ -52,7 +52,6 @@ class SglangSpecificArgs(TypedDict): triton_attention_reduce_in_fp32: NotRequired[bool] triton_attention_num_kv_splits: NotRequired[int] num_continuous_decode_steps: NotRequired[int] - enable_memory_saver: NotRequired[bool] allow_auto_truncate: NotRequired[bool] attention_backend: NotRequired[str | None] enable_multimodal: NotRequired[bool] @@ -100,8 +99,12 @@ class SglangSpecificArgs(TypedDict): rollout_health_check_first_wait: NotRequired[int] class SGLangServer(TypedDict): + # needs_offload true --> enable_memory_saver true needs_offload: bool + # for testing purpose. 
memory_saver + cpu_weight_backup: bool sglang_server_concurrency: int + # total num gpus for inference num_gpus: NotRequired[int] num_gpus_per_engine: NotRequired[int] diff --git a/nemo_rl/models/generation/redesign/misc.py b/nemo_rl/models/generation/redesign/misc.py index 2825761d8e..e2d9ffc5f5 100644 --- a/nemo_rl/models/generation/redesign/misc.py +++ b/nemo_rl/models/generation/redesign/misc.py @@ -1,9 +1,6 @@ import io import logging import multiprocessing -from multiprocessing.reduction import ForkingPickler - -import pybase64 logger = logging.getLogger(__name__) @@ -19,55 +16,6 @@ ] -class MultiprocessingSerializer: # pragma: no cover - """Serialize/deserialize Python objects using ForkingPickler for IPC. - - This class enables serialization of objects (including CUDA tensors with IPC - handles) for transfer between processes via HTTP or other mechanisms. - - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 - """ - - @staticmethod - def serialize(obj, output_str: bool = False): - """Serialize a Python object using ForkingPickler. - - Args: - obj: The object to serialize. - output_str (bool): If True, return a base64-encoded string instead of raw bytes. - - Returns: - bytes or str: The serialized object. - """ - buf = io.BytesIO() - ForkingPickler(buf).dump(obj) - buf.seek(0) - output = buf.read() - - if output_str: - # Convert bytes to base64-encoded string - output = pybase64.b64encode(output).decode("utf-8") - - return output - - @staticmethod - def deserialize(data): - """Deserialize a previously serialized object. - - Args: - data (bytes or str): The serialized data, optionally base64-encoded. - - Returns: - The deserialized Python object. 
- """ - if isinstance(data, str): - # Decode base64 string to bytes - data = pybase64.b64decode(data, validate=True) - - return ForkingPickler.loads(data) - - def run_router(args): try: from sglang_router.launch_router import launch_router diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 24664aedc5..02aea6b3c7 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -554,7 +554,8 @@ def _compute_server_args( "trust_remote_code": True, "random_seed": sglang_cfg["sglang_cfg"]["random_seed"] + rank, # memory - "enable_memory_saver": sglang_cfg["sglang_cfg"]["enable_memory_saver"], + "enable_memory_saver": sglang_cfg["sglang_server"]["needs_offload"], + "enable_weights_cpu_backup": sglang_cfg["sglang_server"]["cpu_weight_backup"], # distributed "host": host, "port": port, diff --git a/nemo_rl/models/generation/sglang/sglang_copied_utils.py b/nemo_rl/models/generation/sglang/sglang_copied_utils.py index aa9eafea01..fe0ee6b8c3 100644 --- a/nemo_rl/models/generation/sglang/sglang_copied_utils.py +++ b/nemo_rl/models/generation/sglang/sglang_copied_utils.py @@ -79,108 +79,4 @@ def deserialize(data): # Decode base64 string to bytes data = pybase64.b64decode(data, validate=True) - return ForkingPickler.loads(data) - - -def monkey_patch_torch_reductions(): # pragma: no cover - """Monkey patch torch multiprocessing reductions to use GPU UUIDs. - - This patch modifies PyTorch's CUDA tensor IPC mechanism to use GPU UUIDs - instead of device indices. This enables proper weight transfer between - processes that may have different CUDA_VISIBLE_DEVICES configurations. - - The patch is idempotent - calling it multiple times is safe. - - This is a workaround before PyTorch https://github.com/pytorch/pytorch/pull/149248 - is merged and released. 
- - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L20-L33 - """ - if hasattr(reductions, "_reduce_tensor_original"): - return - - reductions._reduce_tensor_original = reductions.reduce_tensor - reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor - - reductions.reduce_tensor = _reduce_tensor_modified - reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified - - reductions.init_reductions() - - -# The signature has not been changed for years, and we will not need this when -# the next version is released, so it looks safe to use a constant. -# Original source (sglang v0.5.2): -# https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L36 -_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 - - -def _reduce_tensor_modified(*args, **kwargs): # pragma: no cover - """Modified reduce_tensor that stores GPU UUID instead of device index. - - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L39-L43 - """ - output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) - output_args = _modify_tuple( - output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid - ) - return output_fn, output_args - - -def _rebuild_cuda_tensor_modified(*args): # pragma: no cover - """Modified rebuild_cuda_tensor that accepts GPU UUID or device index. - - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L46-L48 - """ - args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) - return reductions._rebuild_cuda_tensor_original(*args) - - -def _device_to_uuid(device: int) -> str: # pragma: no cover - """Convert a device index to its UUID string. 
- - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L51-L52 - """ - return str(torch.cuda.get_device_properties(device).uuid) - - -def _device_from_maybe_uuid( - device_maybe_uuid: Union[int, str], -) -> int: # pragma: no cover - """Convert a device UUID string or index to a device index. - - Args: - device_maybe_uuid: Either an integer device index or a UUID string. - - Returns: - The integer device index. - - Raises: - Exception: If the UUID doesn't match any available device. - - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L55-L65 - """ - if isinstance(device_maybe_uuid, int): - return device_maybe_uuid - - if isinstance(device_maybe_uuid, str): - for device in range(torch.cuda.device_count()): - if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: - return device - raise Exception("Invalid device_uuid=" + device_maybe_uuid) - - raise Exception(f"Unknown type: {device_maybe_uuid=}") - - -def _modify_tuple(t, index: int, modifier: Callable): # pragma: no cover - """Create a new tuple with one element modified by a function. - - Original source (sglang v0.5.2): - https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/patch_torch.py#L68-L69 - """ - return *t[:index], modifier(t[index]), *t[index + 1 :] + return ForkingPickler.loads(data) \ No newline at end of file diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index f6facfc748..825d874341 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -189,12 +189,15 @@ def stream_weights_via_ipc_zmq( pass def stream_weights_via_http( - self, sglang_url_to_gpu_uuids: dict[str, list[str]] + self, + rollout_engines, + num_gpus_per_engine: int, ) -> list[ray.ObjectRef]: - """Stream model weights to SGLang servers via HTTP API. 
+ """Stream model weights to colocated SGLang engines via CUDA IPC over HTTP. Args: - sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + rollout_engines: Ray actor handles for the SGLang generation workers. + num_gpus_per_engine: TP size per SGLang engine. """ raise NotImplementedError( "stream_weights_via_http is not implemented for this policy worker" diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index e21bd6dac6..fc6fd614b6 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -883,16 +883,21 @@ def stream_weights_via_ipc_zmq( return futures def stream_weights_via_http( - self, sglang_url_to_gpu_uuids: dict[str, list[str]] + self, + rollout_engines, + num_gpus_per_engine: int, ) -> list[ray.ObjectRef]: - """Send the weights to SGLang servers via HTTP API. + """Send the weights to colocated SGLang engines via CUDA IPC over HTTP. Args: - sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + rollout_engines: Ray actor handles for SGLang generation workers + (one per engine on this node). + num_gpus_per_engine: TP size per SGLang engine. """ futures = self.worker_group.run_all_workers_single_data( "stream_weights_via_http", - sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + rollout_engines=rollout_engines, + num_gpus_per_engine=num_gpus_per_engine, ) return futures diff --git a/nemo_rl/models/policy/redesign_utils.py b/nemo_rl/models/policy/redesign_utils.py new file mode 100644 index 0000000000..728e0607c3 --- /dev/null +++ b/nemo_rl/models/policy/redesign_utils.py @@ -0,0 +1,272 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import io +from typing import Callable, Union + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from torch.multiprocessing import reductions +from multiprocessing.reduction import ForkingPickler +import pybase64 + +from packaging import version as pkg_version +torch_release = pkg_version.parse(torch.__version__).release + +SGLANG_TP_RANK = None + +class MultiprocessingSerializer: # pragma: no cover + """Serialize/deserialize Python objects using ForkingPickler for IPC. + + This class enables serialization of objects (including CUDA tensors with IPC + handles) for transfer between processes via HTTP or other mechanisms. + + Original source (sglang v0.5.2): + https://github.com/sgl-project/sglang/blob/v0.5.2/python/sglang/srt/utils.py#L589-L623 + """ + + @staticmethod + def serialize(obj, output_str: bool = False): + """Serialize a Python object using ForkingPickler. + + Args: + obj: The object to serialize. + output_str (bool): If True, return a base64-encoded string instead of raw bytes. + + Returns: + bytes or str: The serialized object. + """ + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + output = buf.read() + + if output_str: + # Convert bytes to base64-encoded string + output = pybase64.b64encode(output).decode("utf-8") + + return output + + @staticmethod + def deserialize(data): + """Deserialize a previously serialized object. + + Args: + data (bytes or str): The serialized data, optionally base64-encoded. 
+ + Returns: + The deserialized Python object. + """ + if isinstance(data, str): + # Decode base64 string to bytes + data = pybase64.b64decode(data, validate=True) + + return ForkingPickler.loads(data) + + +def monkey_patch_torch_reductions(): + """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed""" + + if hasattr(reductions, "_reduce_tensor_original"): + return + reductions._reduce_tensor_original = reductions.reduce_tensor + reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor + + reductions.reduce_tensor = _reduce_tensor_modified + reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified + reductions.init_reductions() + + + +# The signature has not been changed for years, and we will not need this when the next version is released, +# so it looks safe to use a constant. +_REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 + + +def register_sgl_tp_rank(rank: int): + global SGLANG_TP_RANK + SGLANG_TP_RANK = rank + + +def _reduce_tensor_modified(*args, **kwargs): + output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) + output_args = _modify_tuple( + output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid + ) + return output_fn, output_args + + +def _rebuild_cuda_tensor_modified(*args): + args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid) + return reductions._rebuild_cuda_tensor_original(*args) + + +def _device_to_uuid(device: int) -> str: + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int: + if isinstance(device_maybe_uuid, int): + return device_maybe_uuid + + if isinstance(device_maybe_uuid, str): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid: + return device + raise Exception("Invalid device_uuid=" + device_maybe_uuid) + + raise Exception(f"Unknown type: {device_maybe_uuid=}") + + +def 
_modify_tuple(t, index: int, modifier: Callable): + return *t[:index], modifier(t[index]), *t[index + 1 :] + + +def monkey_patch_torch_compile(): + if torch_release < (2, 8): + # These things are cacheable by torch.compile. torch.compile just doesn't know it. + # This was fixed in PyTorch 2.8, but until then, we monkey patch. + import torch._higher_order_ops.auto_functionalize as af + + af.auto_functionalized_v2._cacheable = True + af.auto_functionalized._cacheable = True + + +def register_fake_if_exists(op_name): + """ + Decorator factory to conditionally register a fake for a custom op if it exists. + Parses op_name (e.g., 'sgl_kernel::gptq_gemm'), checks if the op exists via hasattr + on the namespace attribute of torch.ops. Registers the fake if present; otherwise, + returns the function unchanged. + Args: + op_name (str): Full operator name (e.g., 'sgl_kernel::gptq_gemm'). + Returns: + callable: Decorator for the fake function. + Example: + @register_fake_if_exists('sgl_kernel::gptq_gemm') + def fake_gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit): + return a.new_empty((a.shape[0], b_q_weight.shape[-1]), dtype=a.dtype) + """ + + def decorator(func): + namespace, bare_op = op_name.split("::") + ops_namespace = getattr(torch.ops, namespace, None) + if ops_namespace and hasattr(ops_namespace, bare_op): + torch.library.register_fake(op_name, func) + return func + + return decorator + + +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. 
+    """
+
+    # This field is solely for users of this class to check whether it supports this feature
+    supports_multi_dtypes = True
+
+    def __init__(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]] = None,
+        flattened_tensor: torch.Tensor = None,
+        metadata: List[FlattenedTensorMetadata] = None,
+    ):
+        """
+        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+        Args:
+            named_tensors: List of (name, tensor) tuples (for creating new bucket)
+            flattened_tensor: Pre-flattened tensor (for reconstruction)
+            metadata: Pre-computed metadata (for reconstruction)
+        """
+        if named_tensors is not None:
+            # Create bucket from named tensors
+            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+            self.flattened_tensor: torch.Tensor = None
+
+            if not named_tensors:
+                raise ValueError("Cannot create empty tensor bucket")
+
+            # Collect metadata and flatten tensors
+            current_idx = 0
+            flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+            for i, (name, tensor) in enumerate(named_tensors):
+                flattened = tensor.flatten().view(torch.uint8)
+                flattened_tensors[i] = flattened
+
+                # Store metadata
+
+                numel = flattened.numel()
+                metadata_obj = FlattenedTensorMetadata(
+                    name=name,
+                    shape=tensor.shape,
+                    dtype=tensor.dtype,
+                    start_idx=current_idx,
+                    end_idx=current_idx + numel,
+                    numel=numel,
+                )
+                self.metadata[i] = metadata_obj
+                current_idx += numel
+
+            # Concatenate all flattened tensors
+            self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+        else:
+            # Initialize from pre-flattened data
+            if flattened_tensor is None or metadata is None:
+                raise ValueError(
+                    "Must provide either named_tensors or both flattened_tensor and metadata"
+                )
+            self.flattened_tensor = flattened_tensor
+            self.metadata = metadata
+
+    def get_flattened_tensor(self) -> torch.Tensor:
+        """Get the flattened tensor containing all bucket tensors"""
+        return self.flattened_tensor
+
+    def get_metadata(self) -> 
List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. + """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = ( + self.flattened_tensor[meta.start_idx : meta.end_idx] + .view(meta.dtype) + .reshape(meta.shape) + ) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 02eb9e6ff8..b174c37c29 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -381,125 +381,198 @@ def rebuild_cuda_tensor_from_ipc( return func(*list_args) -def stream_weights_via_http_impl( - params_generator, - sglang_url_to_gpu_uuids: dict[str, list[str]], - rank: int, - worker_name: str, - current_device_uuid: str, +def _ensure_ipc_topology( + rollout_engines, + num_gpus_per_engine: int, + worker_state: dict, + monkey_patch_fn, ) -> None: - """Stream weights to SGLang servers via HTTP API (update_weights_from_tensor). + """Lazily create a per-engine Gloo subgroup and cache routing state. - Flow: Each rank creates IPC handler → gather handlers in rank order → send list → SGLang matches by tp_rank index - - Key points: - - Each rank creates handler on its own GPU - - Handlers are gathered in rank order: [rank0_handler, rank1_handler, ...] 
- - List index = rank = GPU ID - - SGLang automatically matches: handler = serialized_handlers[tp_rank] - - Args: - params_generator: Generator yielding (name, tensor) pairs - sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses - rank: Worker rank for logging - worker_name: Name of the worker for logging - current_device_uuid: UUID of the current training worker's GPU + Every FSDP rank must call ``dist.new_group`` for every engine's rank range + (collective). Only the ranks inside a given range stash the group/engine + handles into ``worker_state``. """ - from nemo_rl.models.generation.redesign.misc import ( - MultiprocessingSerializer, - ) + if worker_state.get("ready"): + return - print("[sglang refit details] entering stream_weights_via_http_impl") + monkey_patch_fn() - target_urls = [ - url - for url, uuids in sglang_url_to_gpu_uuids.items() - if current_device_uuid in uuids + my_rank = dist.get_rank() + for i, engine in enumerate(rollout_engines): + start = i * num_gpus_per_engine + group_ranks = list(range(start, start + num_gpus_per_engine)) + grp = dist.new_group(ranks=group_ranks, backend="gloo") + if my_rank in group_ranks: + worker_state["gather_src"] = start + worker_state["gather_group"] = grp + worker_state["engine"] = engine + worker_state["tp_rank"] = my_rank - start + + worker_state.setdefault("weight_version", 0) + worker_state["ready"] = True + + +def _flush_bucket( + named_tensors, + gather_src: int, + gather_group, + engine, + weight_version: int, + flattened_tensor_bucket_cls, + multiprocessing_serializer_cls, +) -> None: + """Flatten ``named_tensors`` per dtype, gather to ``gather_src``, and RPC to engine.""" + import ray + + # Wait on any async DTensor redistributes. + named_tensors = [ + (n, (t.wait() if hasattr(t, "wait") else t)) for n, t in named_tensors ] - if not target_urls: - raise RuntimeError( - f"{worker_name} (rank {rank}): No matching SGLang server found for GPU UUID {current_device_uuid}. 
" - f"Available servers: {list(sglang_url_to_gpu_uuids.keys())}" + by_dtype: dict = {} + for n, t in named_tensors: + by_dtype.setdefault(t.dtype, []).append((n, t)) + + serialized: list[str] = [] + for _dtype, tensors in by_dtype.items(): + bkt = flattened_tensor_bucket_cls(named_tensors=tensors) + payload = { + "flattened_tensor": bkt.get_flattened_tensor(), + "metadata": bkt.get_metadata(), + } + serialized.append( + multiprocessing_serializer_cls.serialize(payload, output_str=True) ) - if len(target_urls) > 1: - print( - f"[WARNING] {worker_name} (rank {rank}): GPU UUID {current_device_uuid} matches multiple SGLang servers: {target_urls}. " - f"Using the first one: {target_urls[0]}" - ) - target_urls = [target_urls[0]] + my_rank = dist.get_rank() + group_world = dist.get_world_size(gather_group) + gathered = [None] * group_world if my_rank == gather_src else None + dist.gather_object( + serialized, + object_gather_list=gathered, + dst=gather_src, + group=gather_group, + ) - base_url = target_urls[0] - url = f"{base_url}/update_weights_from_tensor" - sglang_gpu_uuids = sglang_url_to_gpu_uuids[base_url] + if my_rank != gather_src: + return + + num_dtypes = len(gathered[0]) + assert num_dtypes > 0 + for i in range(num_dtypes): + ref = engine.update_weights_from_tensor.remote( + serialized_named_tensors=[g[i] for g in gathered], + load_format="flattened_bucket", + flush_cache=False, + weight_version=str(weight_version), + ) + result = ray.get(ref) + if isinstance(result, dict): + success = result.get("success", True) + error_msg = ( + result.get("error_message") or result.get("message", "unknown error") + ) + else: + success = getattr(result, "success", True) + error_msg = getattr(result, "error_message", "unknown error") + if not success: + raise RuntimeError( + f"Weight sync failed on rollout engine: {error_msg}. " + f"Check SGLang version compatibility." 
+ ) - ipc_gather_group, ipc_gather_src, matching_ranks = _setup_ipc_gather_group( - rank, current_device_uuid, sglang_gpu_uuids, sglang_url_to_gpu_uuids - ) - print( - f"[sglang refit] {worker_name} (rank {rank}): ipc_gather_group={ipc_gather_group}, ipc_gather_src={ipc_gather_src}, matching_ranks={matching_ranks}" - ) - tensor_count = 0 - try: - tensor_list = list(params_generator) - total_tensors = len(tensor_list) +def stream_weights_via_http_impl( + model: torch.nn.Module, + rollout_engines, + num_gpus_per_engine: int, + rank: int, + world_size: int, + worker_name: str, + buffer_size_bytes: int, + worker_state: dict, +) -> None: + """Stream FSDP weights to colocated SGLang engines via CUDA IPC over HTTP. - if rank == ipc_gather_src: - print( - f"[sglang refit details] {worker_name}: Starting weight update - " - f"Total parameters to update: {total_tensors}", - flush=True, - ) + Implementation mirrors miles' ``UpdateWeightFromTensor``: size-bounded + buckets over ``model.state_dict()``, dtype-grouped ``FlattenedTensorBucket`` + per bucket, per-engine Gloo subgroup gather to a source rank, and a single + Ray ``.remote()`` per (bucket, dtype) to the colocated SGLang worker. - for idx, (name, tensor) in enumerate(tensor_list): - torch.cuda.current_stream().synchronize() - tensor = tensor.contiguous().cuda() + Args: + model: The FSDP-wrapped training model. + rollout_engines: Ray actor handles for SGLang generation workers. + num_gpus_per_engine: TP size per SGLang engine. + rank: Global FSDP rank. + world_size: Global FSDP world size. + worker_name: Human label for logs. + buffer_size_bytes: Max bucket size in bytes. + worker_state: Mutable dict on the worker used to cache topology and + weight version across refits. 
+ """ + import ray + from torch.distributed.tensor import DTensor, Replicate - named_tensors = [(name, tensor)] - serialized_handler = MultiprocessingSerializer.serialize( - named_tensors, output_str=True - ) - # output_str=True ensures the return type is str - serialized_handler_str = cast(str, serialized_handler) - - gathered_handlers = _gather_ipc_handlers( - serialized_handler_str, - ipc_gather_group, - ipc_gather_src, - rank, - matching_ranks, - ) + from nemo_rl.models.policy.redesign_utils import ( + FlattenedTensorBucket, + MultiprocessingSerializer, + monkey_patch_torch_reductions, + ) - if rank == ipc_gather_src and gathered_handlers is not None: - _send_tensor_to_sglang( - url, - name, - gathered_handlers, - tensor.shape, - str(tensor.dtype), - flush_cache=False, - ) - tensor_count += 1 + _ensure_ipc_topology( + rollout_engines=rollout_engines, + num_gpus_per_engine=num_gpus_per_engine, + worker_state=worker_state, + monkey_patch_fn=monkey_patch_torch_reductions, + ) - del tensor, serialized_handler - if rank == ipc_gather_src: - del gathered_handlers - torch.cuda.empty_cache() + worker_state["weight_version"] = worker_state.get("weight_version", 0) + 1 + weight_version = worker_state["weight_version"] + gather_src = worker_state["gather_src"] + gather_group = worker_state["gather_group"] + engine = worker_state["engine"] - if rank == ipc_gather_src: - print( - f"[sglang refit details] {worker_name}: Weight update completed - " - f"Successfully updated {tensor_count}/{total_tensors} parameters to SGLang server: {base_url}", - flush=True, - ) - if tensor_count != total_tensors: - print( - f"[sglang refit details] {worker_name}: WARNING - Expected {total_tensors} tensors, " - f"but only sent {tensor_count}", - flush=True, + try: + bucket: list = [] + bucket_size = 0 + for name, param in model.state_dict().items(): + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= buffer_size_bytes: + _flush_bucket( + bucket, + 
gather_src=gather_src, + gather_group=gather_group, + engine=engine, + weight_version=weight_version, + flattened_tensor_bucket_cls=FlattenedTensorBucket, + multiprocessing_serializer_cls=MultiprocessingSerializer, ) + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: + _flush_bucket( + bucket, + gather_src=gather_src, + gather_group=gather_group, + engine=engine, + weight_version=weight_version, + flattened_tensor_bucket_cls=FlattenedTensorBucket, + multiprocessing_serializer_cls=MultiprocessingSerializer, + ) + + if dist.get_rank() == gather_src: + ray.get(engine.flush_cache.remote()) except Exception as e: print( @@ -507,132 +580,6 @@ def stream_weights_via_http_impl( f"{traceback.format_exc()}" ) raise - finally: gc.collect() torch.cuda.empty_cache() - - -def _setup_ipc_gather_group( - rank: int, - current_device_uuid: str, - sglang_gpu_uuids: list[str], - sglang_url_to_gpu_uuids: dict[str, list[str]], -) -> tuple[Optional[dist.ProcessGroup], Optional[int], Optional[list[int]]]: - """Setup gather configuration for IPC handlers. 
- - Returns: - Tuple of (gather_group, gather_src_rank, matching_ranks) - - gather_group: None (use default FSDP group) - - gather_src_rank: The rank that will collect and send to SGLang server - - matching_ranks: List of ranks that belong to the same SGLang server - """ - if not dist.is_initialized(): - return None, None, None - - world_size = dist.get_world_size() - my_rank = dist.get_rank() - - all_ranks_uuids = [None] * world_size - dist.all_gather_object(all_ranks_uuids, current_device_uuid) - - matching_ranks = [ - r for r, uuid in enumerate(all_ranks_uuids) if uuid in sglang_gpu_uuids - ] - - if len(matching_ranks) == 0: - return None, None, None - - matching_ranks = sorted(matching_ranks) - gather_src = matching_ranks[0] - - return None, gather_src, matching_ranks - - -def _gather_ipc_handlers( - serialized_handler: str, - gather_group: Optional[dist.ProcessGroup], - gather_src: Optional[int], - rank: int, - matching_ranks: Optional[list[int]] = None, -) -> Optional[list[str]]: - """Gather IPC handlers from all ranks in the default FSDP group, then filter by server. 
- - Args: - serialized_handler: Serialized IPC handler from this rank - gather_group: Process group (None means use default FSDP group) - gather_src: Rank that will collect and filter handlers - rank: Current rank - matching_ranks: List of ranks that belong to the same SGLang server - - Returns: - List of serialized handlers in rank order (only on gather_src rank), None otherwise - The list contains handlers from matching_ranks only, in rank order - """ - if gather_src is None: - return None - - if not dist.is_initialized(): - return None - - world_size = dist.get_world_size() - - all_handlers: list[Optional[str]] = [None for _ in range(world_size)] - dist.all_gather_object(all_handlers, serialized_handler) - all_handlers_str = cast(list[str], all_handlers) - - if rank == gather_src and matching_ranks is not None: - filtered_handlers: list[str] = [all_handlers_str[r] for r in matching_ranks] - return filtered_handlers - else: - return None - - -def _send_tensor_to_sglang( - url: str, - tensor_name: str, - gathered_handlers: list[str], - shape: torch.Size, - dtype: str, - flush_cache: bool = False, -) -> None: - """Send gathered IPC handlers to SGLang server via HTTP. - - Key: gathered_handlers are in rank order [rank0, rank1, ...] 
- SGLang will automatically match: handler = serialized_handlers[tp_rank] - - Args: - url: SGLang server URL - tensor_name: Name of the tensor - gathered_handlers: List of serialized IPC handlers in rank order - shape: Tensor shape - dtype: Tensor dtype - flush_cache: Whether to flush cache after this tensor (for last tensor) - """ - payload = { - "serialized_named_tensors": gathered_handlers, - "flush_cache": flush_cache, - } - - try: - response = requests.post( - url, - json=payload, - headers={"Content-Type": "application/json"}, - timeout=120, - ) - response.raise_for_status() - except requests.exceptions.HTTPError as e: - error_msg = f"Failed to send tensor '{tensor_name}' to {url}: {e}" - try: - error_detail = response.text - error_msg += f"\nResponse status: {response.status_code}" - error_msg += f"\nResponse body: {error_detail[:500]}" - except: - pass - print(f"[sglang refit] {error_msg}", flush=True) - raise RuntimeError(error_msg) from e - except Exception as e: - raise RuntimeError( - f"Failed to send tensor '{tensor_name}' to {url}: {e}" - ) from e diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 2fa8a8e604..2d99ee4448 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -911,12 +911,18 @@ def stream_weights_via_ipc_zmq( @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_http") def stream_weights_via_http( self, - sglang_url_to_gpu_uuids: dict[str, list[str]], + rollout_engines, + num_gpus_per_engine: int, + buffer_size_bytes: int = 512 * 1024 * 1024, ) -> None: - """Stream model weights to SGLang servers via HTTP API. + """Stream FSDP weights to colocated SGLang engines via CUDA IPC over HTTP. 
Args: - sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses + rollout_engines: Ray actor handles for SGLang generation workers, + one per engine on this node. + num_gpus_per_engine: TP size per SGLang engine. Engine ``i`` is + assumed to own global ranks ``[i*K, (i+1)*K)``. + buffer_size_bytes: Max bucket size in bytes before flushing. """ # Manually move model to cuda for cpu offload case if self.cpu_offload: @@ -924,34 +930,18 @@ def stream_weights_via_http( from nemo_rl.models.policy.utils import stream_weights_via_http_impl - # Get current GPU UUID - current_device_uuid = self.report_device_id() + if not hasattr(self, "_ipc_worker_state"): + self._ipc_worker_state: dict = {} - def dtensor_params_generator(): - """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" - state_dict_items = sorted( - self.model.state_dict().items(), key=lambda x: x[0] - ) - for name, tensor in state_dict_items: - if isinstance(tensor, DTensor): - # Convert DTensor to full tensor for streaming - full_tensor = tensor.full_tensor() - # Convert to target dtype - yield ( - name, - full_tensor.to(self.dtype, non_blocking=True).contiguous(), - ) - else: - # Convert to target dtype - yield name, tensor.to(self.dtype, non_blocking=True).contiguous() - - # Use the HTTP implementation stream_weights_via_http_impl( - params_generator=dtensor_params_generator(), - sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + model=self.model, + rollout_engines=rollout_engines, + num_gpus_per_engine=num_gpus_per_engine, rank=self.rank, + world_size=torch.distributed.get_world_size(), worker_name=str(self), - current_device_uuid=current_device_uuid, + buffer_size_bytes=buffer_size_bytes, + worker_state=self._ipc_worker_state, ) @torch.no_grad() diff --git a/tests/unit/models/generation/redesign/helpers.py b/tests/unit/models/generation/redesign/helpers.py index 0ed09d3039..cfb3db8bac 100644 --- a/tests/unit/models/generation/redesign/helpers.py 
+++ b/tests/unit/models/generation/redesign/helpers.py @@ -61,7 +61,6 @@ def make_sglang_cfg( "sglang_cfg": { "model_path": model_path, "random_seed": 42, - "enable_memory_saver": False, "dp_size": 1, "pp_size": 1, "ep_size": 1, diff --git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py index 3a5a8a4116..7560ac066f 100644 --- a/tests/unit/models/generation/redesign/test_sglang_generation.py +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -80,7 +80,6 @@ def _make_sglang_generation_cfg(pad_token_id=PAD_TOKEN_ID, tp_size=1): "context_length": 1024, "log_level": "warning", "skip_server_warmup": True, - "enable_memory_saver": False, "dp_size": 1, "pp_size": 1, "ep_size": 1, @@ -90,6 +89,7 @@ def _make_sglang_generation_cfg(pad_token_id=PAD_TOKEN_ID, tp_size=1): "num_gpus": 4, "num_gpus_per_engine": tp_size, "needs_offload": True, + "cpu_weight_backup": True, "sglang_server_concurrency": 64, }, "sglang_router": { diff --git a/tests/unit/models/generation/redesign/test_utils_smoke.py b/tests/unit/models/generation/redesign/test_utils_smoke.py index cedeecdb35..ec09af5356 100644 --- a/tests/unit/models/generation/redesign/test_utils_smoke.py +++ b/tests/unit/models/generation/redesign/test_utils_smoke.py @@ -23,9 +23,11 @@ from nemo_rl.models.generation.redesign.async_utils import AsyncLoopThread from nemo_rl.models.generation.redesign.misc import ( - MultiprocessingSerializer, terminate_process, ) +from nemo_rl.models.policy.redesign_utils import ( + MultiprocessingSerializer, +) from nemo_rl.models.generation.redesign.ray_utils import ( _wrap_ipv6, find_available_port, diff --git a/tests/unit/models/generation/redesign/test_weight_update_real.py b/tests/unit/models/generation/redesign/test_weight_update_real.py new file mode 100644 index 0000000000..8bb05764b2 --- /dev/null +++ b/tests/unit/models/generation/redesign/test_weight_update_real.py @@ -0,0 +1,473 @@ +# 
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end weight update tests using SGLangGeneration + mock FSDP trainer. + +Verifies the full weight-streaming path: + 1. SGLangGeneration.check_weights("snapshot") — save original weights + 2. SGLangGeneration.check_weights("reset_tensors") — randomize weights + 3. Mock FSDP trainer streams Qwen3-1.7B weights via stream_weights_via_http_impl + 4. SGLangGeneration.check_weights("compare") — verify restored weights + +Parametrised over two configurations (both require 4 GPUs): + • 1 server × TP=4 — single-server high-TP + • 2 servers × TP=2 — multi-server routing + +Model: Qwen/Qwen3-1.7B +""" + +import gc +import os + +import pytest +import ray +import torch +import torch.distributed as dist + +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.models.generation.redesign.ray_utils import find_available_port, get_host_info +from nemo_rl.models.generation.redesign.sglang_generation import SGLangGeneration + +from helpers import make_actor_env_vars, post_and_assert_200 + +pytestmark = pytest.mark.sglang + +MODEL_PATH = "Qwen/Qwen3-1.7B" + + +# --------------------------------------------------------------------------- +# SGLang config builder +# --------------------------------------------------------------------------- +def _make_sglang_cfg(tp_size): 
+ return { + "sglang_cfg": { + "model_path": MODEL_PATH, + "dtype": "bfloat16", + "random_seed": 42, + "context_length": 1024, + "log_level": "warning", + "skip_server_warmup": True, + "dp_size": 1, + "pp_size": 1, + "ep_size": 1, + "disable_piecewise_cuda_graph": True, + }, + "sglang_server": { + "num_gpus": 4, + "num_gpus_per_engine": tp_size, + "needs_offload": True, + "sglang_server_concurrency": 64, + }, + "sglang_router": { + "sglang_router_ip": None, + "sglang_router_port": None, + }, + } + + +# --------------------------------------------------------------------------- +# Mock FSDP trainer worker +# --------------------------------------------------------------------------- +@ray.remote(num_cpus=0.1) +class MockFSDPWorker: + """Simulates one FSDP rank for weight streaming. + + Loads the full model on a single GPU and calls the real + ``stream_weights_via_http_impl`` to send weights to SGLang servers. + """ + + def init(self, rank, world_size, master_addr, master_port, model_path, gpu_index): + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(gpu_index) + os.environ["WORLD_SIZE"] = str(world_size) + + self.rank = rank + self.gpu_index = gpu_index + + torch.cuda.set_device(gpu_index) + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + + from transformers import AutoModelForCausalLM + + device = torch.device(f"cuda:{gpu_index}") + self.model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, + ).to(device) + + from nemo_rl.utils.nvml import get_device_uuid + + self.device_uuid = get_device_uuid(gpu_index) + + def get_device_uuid(self): + return self.device_uuid + + def stream_weights(self, rollout_engines, num_gpus_per_engine): + from nemo_rl.models.policy.utils import stream_weights_via_http_impl + + if not hasattr(self, "_ipc_worker_state"): + self._ipc_worker_state = {} + + 
stream_weights_via_http_impl( + model=self.model, + rollout_engines=rollout_engines, + num_gpus_per_engine=num_gpus_per_engine, + rank=self.rank, + world_size=dist.get_world_size(), + worker_name=f"MockFSDPWorker-{self.rank}", + buffer_size_bytes=512 * 1024 * 1024, + worker_state=self._ipc_worker_state, + ) + + def shutdown(self): + if dist.is_initialized(): + dist.destroy_process_group() + if hasattr(self, "model"): + del self.model + gc.collect() + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture( + params=[ + pytest.param({"tp_size": 4, "num_servers": 1}, id="tp4_1server"), + pytest.param({"tp_size": 2, "num_servers": 2}, id="tp2_2servers"), + ] +) +def sglang_gen(request, ray_cluster): + """Real SGLangGeneration: RayVirtualCluster → router → engines.""" + cfg = request.param + tp_size = cfg["tp_size"] + + cluster = RayVirtualCluster( + bundle_ct_per_node_list=[4], + use_gpus=True, + max_colocated_worker_groups=2, + num_gpus_per_node=4, + name="weight-update-test", + ) + cluster_cfg = {"gpus_per_node": 4, "num_nodes": 1} + sglang_cfg = _make_sglang_cfg(tp_size) + + gen = SGLangGeneration(cluster, cluster_cfg, sglang_cfg) + yield gen + + try: + gen.shutdown() + except Exception: + pass + try: + cluster.shutdown() + except Exception: + pass + gc.collect() + torch.cuda.empty_cache() + + +@pytest.fixture +def mock_trainer(ray_cluster, sglang_gen): + """4 MockFSDPWorker actors with torch.distributed (gloo), each loading Qwen3-1.7B. + + Actors are launched into the SGLang cluster's placement group using + PlacementGroupSchedulingStrategy with fractional GPU (num_gpus=0.2), so + they co-reside with the SGLang worker (which also takes 0.2) on the same + bundles. 
This matches the nemo_rl colocated-mode and miles patterns; the + PG's bundles have ``CPU: max_colocated_worker_groups`` capacity (=2) to + fit both worker groups. + """ + host_ip = get_host_info()[1] + master_port = find_available_port(29500) + env_vars = make_actor_env_vars() + + pg = sglang_gen.cluster.get_placement_groups()[0] + + workers = [] + for rank in range(4): + w = MockFSDPWorker.options( + num_cpus=0.2, + num_gpus=0.2, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=rank, + ), + runtime_env={"env_vars": env_vars}, + ).remote() + workers.append(w) + + # All workers must init simultaneously (gloo rendezvous). + ray.get([ + w.init.remote( + rank=rank, + world_size=4, + master_addr=host_ip, + master_port=master_port, + model_path=MODEL_PATH, + gpu_index=rank, + ) + for rank, w in enumerate(workers) + ]) + + yield workers + + for w in workers: + try: + ray.get(w.shutdown.remote()) + except Exception: + pass + ray.kill(w) + gc.collect() + torch.cuda.empty_cache() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +def test_weight_update_roundtrip(sglang_gen, mock_trainer): + """Snapshot -> reset -> offload -> update -> compare -> onload_kv. + + Exercises the full colocated-refit memory dance: + snapshot -> reset_tensors -> offload_weights -> offload_kv -> + onload_weights -> update_weights (stream) -> check compare -> + onload_kv. + """ + # 1. Snapshot original Qwen3-1.7B weights. + print("[STEP 1/7] Snapshotting original weights...", flush=True) + sglang_gen.check_weights("snapshot") + print("[STEP 1/7] Snapshot complete.", flush=True) + + # 2. Randomize all model weights on the SGLang servers. + print("[STEP 2/7] Randomizing (reset_tensors) model weights...", flush=True) + sglang_gen.check_weights("reset_tensors") + print("[STEP 2/7] Reset complete.", flush=True) + + # 3. 
Offload weights and KV cache to CPU (refit prelude). + print("[STEP 3/7] Offloading weights and KV cache to CPU...", flush=True) + sglang_gen.offload_weights() + sglang_gen.offload_kv() + print("[STEP 3/7] Offload complete.", flush=True) + + # 4. Onload weight buffers back to GPU so IPC handles can target them. + print("[STEP 4/7] Onloading weight buffers back to GPU...", flush=True) + sglang_gen.onload_weights() + print("[STEP 4/7] Onload weights complete.", flush=True) + + # 5. All 4 mock FSDP workers stream weights simultaneously via CUDA IPC over HTTP. + print("[STEP 5/7] Streaming weights from mock FSDP workers via CUDA IPC...", flush=True) + rollout_engines = sglang_gen.rollout_engines + num_gpus_per_engine = sglang_gen.num_gpus_per_engine + ray.get([ + w.stream_weights.remote(rollout_engines, num_gpus_per_engine) + for w in mock_trainer + ]) + print("[STEP 5/7] Weight streaming complete.", flush=True) + + # 6. Compare current weights against snapshot - raises on mismatch. + print("[STEP 6/7] Comparing current weights against snapshot...", flush=True) + sglang_gen.check_weights("compare") + print("[STEP 6/7] Compare passed.", flush=True) + + # 7. Onload KV cache to finish the refit cycle. + print("[STEP 7/7] Onloading KV cache to finish refit cycle...", flush=True) + sglang_gen.onload_kv() + print("[STEP 7/7] Roundtrip complete.", flush=True) + + +# --------------------------------------------------------------------------- +# Test: roundtrip + router-based generate() + greedy before/after comparison +# --------------------------------------------------------------------------- +def test_weight_update_roundtrip_with_router_generation(sglang_gen, mock_trainer): + """Full refit roundtrip with generation via router and per-worker HTTP 200 checks. + + Differs from ``test_weight_update_roundtrip`` in two ways: + + 1. 
*Generation through the router.* Both the pre-snapshot and post-onload_kv + generations go through ``sglang_gen.generate(..., greedy=True)`` which + calls ``generate_one_sample(router_ip, router_port, ...)`` — i.e. an + HTTP POST to ``http://{router_ip}:{router_port}/generate``, not to + any individual server. (``sglang_gen.generate`` internally calls + ``resp.raise_for_status()`` so a successful return implies HTTP 200.) + Parametrised over ``tp4_1server`` and ``tp2_2servers``; both configs + share the same router, so the same generation path is exercised in + both. + 2. *Per-worker HTTP 200 checks for the refit cycle.* Instead of calling + ``sglang_gen.check_weights(...)`` / ``offload_weights`` / ``offload_kv`` / + ``onload_weights`` / ``onload_kv``, this test iterates + ``sglang_gen.engines`` and drives the equivalent HTTP endpoints on + **every worker** directly via ``post_and_assert_200`` (same pattern as + ``tests/unit/models/generation/redesign/test_sglang_worker_memory.py``). + That way every single memory/weights transition is verified to return + ``resp.status_code == 200`` — ``_make_request`` would hide the status + behind ``raise_for_status()``. + + Strict outer check: with ``temperature=0.0`` the mock FSDP trainer streams + the original Qwen3-1.7B weights back, so pre- and post-roundtrip greedy + token sequences must match exactly. + """ + from transformers import AutoTokenizer + + from nemo_rl.distributed.batched_data_dict import BatchedDataDict + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + # Sanity: router endpoint is set so sglang_gen.generate actually routes. + assert sglang_gen.router_ip is not None and sglang_gen.router_port is not None, ( + "router_ip/router_port not set on sglang_gen — generate() would not route" + ) + print( + f"[setup] Router endpoint: http://{sglang_gen.router_ip}:{sglang_gen.router_port}", + flush=True, + ) + + # All logical-engine node-0 actors (one per SGLang server). 
+ engines = [e for e in sglang_gen.engines if e is not None] + assert len(engines) >= 1, "sglang_gen has no engines" + base_urls = ray.get([e.get_base_url.remote() for e in engines]) + assert all(u is not None for u in base_urls), f"missing base_url in {base_urls}" + print(f"[setup] {len(engines)} worker(s); base_urls={base_urls}", flush=True) + + # --- Per-worker HTTP helpers ----------------------------------------------- + def _http_check_weights_all(action: str): + """POST /weights_checker on every worker, asserting 200 each time.""" + for url in base_urls: + post_and_assert_200(url, "weights_checker", {"action": action}) + + def _http_release_weights_all(): + """Flush cache + POST /release_memory_occupation(tags=[weights]) per worker.""" + for engine, url in zip(engines, base_urls): + ray.get(engine.flush_cache.remote()) + post_and_assert_200( + url, "release_memory_occupation", {"tags": ["weights"]} + ) + + def _http_release_kv_all(): + """Flush cache + POST /release_memory_occupation(tags=[kv_cache, cuda_graph]).""" + for engine, url in zip(engines, base_urls): + ray.get(engine.flush_cache.remote()) + post_and_assert_200( + url, + "release_memory_occupation", + {"tags": ["kv_cache", "cuda_graph"]}, + ) + + def _http_resume_weights_all(): + """POST /resume_memory_occupation(tags=[weights]) per worker.""" + for url in base_urls: + post_and_assert_200( + url, "resume_memory_occupation", {"tags": ["weights"]} + ) + + def _http_resume_kv_all(): + """POST /resume_memory_occupation(tags=[kv_cache, cuda_graph]) per worker.""" + for url in base_urls: + post_and_assert_200( + url, + "resume_memory_occupation", + {"tags": ["kv_cache", "cuda_graph"]}, + ) + + # --- Router-based greedy generation --------------------------------------- + test_prompt = "The capital of France is" + input_ids = tokenizer.encode(test_prompt, add_special_tokens=True) + input_len = len(input_ids) + + data = BatchedDataDict( + { + "input_ids": torch.tensor([input_ids], dtype=torch.long), + 
"input_lengths": torch.tensor([input_len], dtype=torch.long), + } + ) + + def _generate(tag): + result = sglang_gen.generate(data, greedy=True) + for key in ("output_ids", "generation_lengths", "unpadded_sequence_lengths", "logprobs"): + assert key in result, f"[{tag}] generate() output missing key: {key}" + gen_len = int(result["generation_lengths"][0].item()) + assert gen_len > 0, f"[{tag}] generate() returned 0 tokens (no new tokens generated)" + tokens = result["output_ids"][0, input_len : input_len + gen_len].tolist() + assert all(isinstance(t, int) for t in tokens), ( + f"[{tag}] output tokens should be ints, got {tokens!r}" + ) + text = tokenizer.decode(tokens, skip_special_tokens=True) + assert len(text) > 0, f"[{tag}] decoded generated text is empty" + print(f"[{tag}] gen_len={gen_len} tokens={tokens} text={text!r}", flush=True) + return tokens + + # --- Generation BEFORE snapshot (via router) ------------------------------- + print("[PRE] Router greedy generate() before snapshot...", flush=True) + tokens_before = _generate("PRE") + + # --- Steps 1-7, every HTTP call asserted 200 ------------------------------ + print("[STEP 1/7] Snapshotting original weights (HTTP weights_checker×workers)...", flush=True) + _http_check_weights_all("snapshot") + print("[STEP 1/7] Snapshot complete.", flush=True) + + print("[STEP 2/7] Randomizing weights (HTTP weights_checker reset_tensors×workers)...", flush=True) + _http_check_weights_all("reset_tensors") + print("[STEP 2/7] Reset complete.", flush=True) + + print("[STEP 3/7] Offloading weights + KV (HTTP release_memory_occupation×workers)...", flush=True) + _http_release_weights_all() + _http_release_kv_all() + print("[STEP 3/7] Offload complete.", flush=True) + + print("[STEP 4/7] Onloading weights (HTTP resume_memory_occupation weights×workers)...", flush=True) + _http_resume_weights_all() + print("[STEP 4/7] Onload weights complete.", flush=True) + + print("[STEP 5/7] Streaming weights from mock FSDP workers via CUDA 
IPC...", flush=True) + rollout_engines = sglang_gen.rollout_engines + num_gpus_per_engine = sglang_gen.num_gpus_per_engine + ray.get([ + w.stream_weights.remote(rollout_engines, num_gpus_per_engine) + for w in mock_trainer + ]) + print("[STEP 5/7] Weight streaming complete.", flush=True) + + print("[STEP 6/7] Compare vs snapshot (HTTP weights_checker compare×workers)...", flush=True) + _http_check_weights_all("compare") + print("[STEP 6/7] Compare passed.", flush=True) + + print("[STEP 7/7] Onloading KV (HTTP resume_memory_occupation kv×workers)...", flush=True) + _http_resume_kv_all() + print("[STEP 7/7] Roundtrip complete.", flush=True) + + # --- Generation AFTER onload_kv (via router) ------------------------------- + print("[POST] Router greedy generate() after onload_kv...", flush=True) + tokens_after = _generate("POST") + + # --- Sanity: generate() actually produced new tokens on BOTH runs ---------- + assert len(tokens_before) > 0, "generate() returned no tokens before roundtrip" + assert len(tokens_after) > 0, "generate() returned no tokens after roundtrip" + assert len(tokens_before) == len(tokens_after), ( + f"Different number of generated tokens before vs. after: " + f"before={len(tokens_before)}, after={len(tokens_after)}" + ) + + # --- Strict equality (greedy => deterministic across roundtrip) ------------ + assert tokens_before == tokens_after, ( + "Greedy tokens changed across the refit roundtrip:\n" + f" before (pre-snapshot): {tokens_before}\n" + f" after (post-onload_kv): {tokens_after}" + ) + print( + f"[ASSERT] Greedy tokens match before vs. 
after roundtrip " + f"(n={len(tokens_before)} tokens, both non-empty, both via router).", + flush=True, + ) From 8528f8e03f25ca2f3bf4d7231a4709a7e9bc81de Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 11:39:54 -0700 Subject: [PATCH 30/36] update --- tests/unit/models/generation/redesign/helpers.py | 1 + .../generation/redesign/test_sglang_worker_memory.py | 10 +++++----- .../generation/redesign/test_weight_update_real.py | 1 + 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unit/models/generation/redesign/helpers.py b/tests/unit/models/generation/redesign/helpers.py index cfb3db8bac..d1d716e2e9 100644 --- a/tests/unit/models/generation/redesign/helpers.py +++ b/tests/unit/models/generation/redesign/helpers.py @@ -74,6 +74,7 @@ def make_sglang_cfg( "num_gpus": num_gpus, "num_gpus_per_engine": tp_size, "needs_offload": True, + "cpu_weight_backup": False, "sglang_server_concurrency": 64, }, "sglang_router": { diff --git a/tests/unit/models/generation/redesign/test_sglang_worker_memory.py b/tests/unit/models/generation/redesign/test_sglang_worker_memory.py index f99b773367..604b6cddcc 100644 --- a/tests/unit/models/generation/redesign/test_sglang_worker_memory.py +++ b/tests/unit/models/generation/redesign/test_sglang_worker_memory.py @@ -16,10 +16,10 @@ flush_cache, release_memory_occupation, resume_memory_occupation. 
Uses a real SGLang server (Qwen3-0.6B), parametrised over two -configurations so the same tests exercise both a single-rank TP=1 -worker and a TP=2 worker: +configurations so the same tests exercise both a single-worker TP=4 +setup and a two-worker TP=2 setup: - • tp1 — 1 worker × TP=1 + • tp4 — 1 worker × TP=4 • tp2_2workers — 2 workers × TP=2 (the memory tests target worker 0, but both workers share the router) @@ -38,14 +38,14 @@ @pytest.fixture( scope="module", params=[ - pytest.param({"tp_size": 1, "num_workers": 1}, id="tp1"), + pytest.param({"tp_size": 4, "num_workers": 1}, id="tp4"), pytest.param({"tp_size": 2, "num_workers": 2}, id="tp2_2workers"), ], ) def worker(request, ray_cluster, router): """Worker(s) dedicated to memory tests. - For ``tp1`` a single TP=1 worker is created. For ``tp2_2workers`` + For ``tp4`` a single TP=4 worker is created. For ``tp2_2workers`` two TP=2 workers share the same router (mirroring the 2-servers configuration exercised elsewhere); memory tests run against the first worker but the second is kept alive so the router has the diff --git a/tests/unit/models/generation/redesign/test_weight_update_real.py b/tests/unit/models/generation/redesign/test_weight_update_real.py index 8bb05764b2..cda4986a4b 100644 --- a/tests/unit/models/generation/redesign/test_weight_update_real.py +++ b/tests/unit/models/generation/redesign/test_weight_update_real.py @@ -69,6 +69,7 @@ def _make_sglang_cfg(tp_size): "num_gpus": 4, "num_gpus_per_engine": tp_size, "needs_offload": True, + "cpu_weight_backup": False, "sglang_server_concurrency": 64, }, "sglang_router": { From 67f2e16849e960944f971b245c31f3338834431c Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 15:21:33 -0700 Subject: [PATCH 31/36] update --- .../redesign/sglang-nemo-rl-compat.patch | 48 +++++++++++++++++++ nemo_rl/models/policy/redesign_utils.py | 25 +++++++--- .../redesign/test_sglang_generation.py | 3 +- .../redesign/test_weight_update_real.py | 22 +++++++-- 4 files 
changed, 87 insertions(+), 11 deletions(-) create mode 100644 nemo_rl/models/generation/redesign/sglang-nemo-rl-compat.patch diff --git a/nemo_rl/models/generation/redesign/sglang-nemo-rl-compat.patch b/nemo_rl/models/generation/redesign/sglang-nemo-rl-compat.patch new file mode 100644 index 0000000000..1d8b088901 --- /dev/null +++ b/nemo_rl/models/generation/redesign/sglang-nemo-rl-compat.patch @@ -0,0 +1,48 @@ +# SGLang compatibility patch for NeMo-RL colocated refit +# +# Base: lmsysorg/sglang:dev at commit 2e70e4f +# Target: /sgl-workspace/sglang/ +# Apply: cd /sgl-workspace/sglang && git apply /path/to/this.patch +# +# Two changes: +# +# 1. SafeUnpickler.ALLOWED_MODULE_PREFIXES +# NeMo-RL monkey-patches torch CUDA reductions with helpers that live in +# nemo_rl.models.policy.redesign_utils (_reduce_tensor_modified / +# _rebuild_cuda_tensor_modified -- they translate device index <-> GPU UUID +# so IPC handles survive mismatched CUDA_VISIBLE_DEVICES between trainer +# and SGLang worker). ForkingPickler emits that module path in the pickle +# stream sent to update_weights_from_tensor. SafeUnpickler (CVE-2025-10164 +# mitigation) rejects anything outside its allowlist, so we add the +# single submodule prefix used by NeMo-RL. +# +# 2. WeightChecker._reset_tensors skip list +# cos_sin_cache / freqs_cis / _weight_fp32 are derived/precomputed buffers +# rather than trainable weights; overwriting them with random noise breaks +# attention and prevents the compare step from converging. 
Mirrors the +# sglang-miles branch: +# https://github.com/sgl-project/sglang/blob/sglang-miles/python/sglang/srt/utils/weight_checker.py + +diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py +--- a/python/sglang/srt/utils/common.py ++++ b/python/sglang/srt/utils/common.py +@@ -2157,6 +2157,7 @@ class SafeUnpickler(pickle.Unpickler): + "sglang.srt.layers.", + "sglang.srt.utils.", + "torch_npu.", ++ "nemo_rl.models.policy.redesign_utils.", + } + + DENY_CLASSES = { +diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py +--- a/python/sglang/srt/utils/weight_checker.py ++++ b/python/sglang/srt/utils/weight_checker.py +@@ -38,6 +38,8 @@ class WeightChecker: + + def _reset_tensors(self): + for name, param in self._model_state(): ++ if "cos_sin_cache" in name or "freqs_cis" in name or "_weight_fp32" in name: ++ continue + param.copy_(_random_like(param)) + + def _compare(self): diff --git a/nemo_rl/models/policy/redesign_utils.py b/nemo_rl/models/policy/redesign_utils.py index 728e0607c3..93fb2b999c 100644 --- a/nemo_rl/models/policy/redesign_utils.py +++ b/nemo_rl/models/policy/redesign_utils.py @@ -25,7 +25,24 @@ from packaging import version as pkg_version torch_release = pkg_version.parse(torch.__version__).release -SGLANG_TP_RANK = None +# /sgl-workspace/sglang/python/sglang/srt/utils/common.py +# class SafeUnpickler(pickle.Unpickler): +# ALLOWED_MODULE_PREFIXES = { +# ... 
+# "sglang.srt.weight_sync.tensor_bucket.", +# "sglang.srt.model_executor.model_runner.", +# "sglang.srt.layers.", +# "sglang.srt.utils.", +# "torch_npu.", +# + "nemo_rl.models.policy.redesign_utils.", +# } + +# Refer: https://github.com/sgl-project/sglang/blob/sglang-miles/python/sglang/srt/utils/weight_checker.py +# def _reset_tensors(self): +# for name, param in self._model_state(): +# if "cos_sin_cache" in name or "freqs_cis" in name or "_weight_fp32" in name: +# continue +# param.copy_(_random_like(param)) class MultiprocessingSerializer: # pragma: no cover """Serialize/deserialize Python objects using ForkingPickler for IPC. @@ -94,12 +111,6 @@ def monkey_patch_torch_reductions(): # so it looks safe to use a constant. _REDUCE_TENSOR_ARG_DEVICE_INDEX = 6 - -def register_sgl_tp_rank(rank: int): - global SGLANG_TP_RANK - SGLANG_TP_RANK = rank - - def _reduce_tensor_modified(*args, **kwargs): output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs) output_args = _modify_tuple( diff --git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py index 7560ac066f..88862f1b08 100644 --- a/tests/unit/models/generation/redesign/test_sglang_generation.py +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -42,11 +42,12 @@ from nemo_rl.models.generation.interfaces import GenerationDatumSpec from helpers import ( - MODEL_PATH, make_generation_sampling_params, post_and_assert_200, ) +MODEL_PATH = "Qwen/Qwen3-4B" + pytestmark = pytest.mark.sglang # --------------------------------------------------------------------------- diff --git a/tests/unit/models/generation/redesign/test_weight_update_real.py b/tests/unit/models/generation/redesign/test_weight_update_real.py index cda4986a4b..7bea419222 100644 --- a/tests/unit/models/generation/redesign/test_weight_update_real.py +++ b/tests/unit/models/generation/redesign/test_weight_update_real.py @@ -12,7 +12,8 @@ # 
See the License for the specific language governing permissions and # limitations under the License. -"""End-to-end weight update tests using SGLangGeneration + mock FSDP trainer. +""" +End-to-end weight update tests using SGLangGeneration + mock FSDP trainer. Verifies the full weight-streaming path: 1. SGLangGeneration.check_weights("snapshot") — save original weights @@ -45,14 +46,28 @@ pytestmark = pytest.mark.sglang -MODEL_PATH = "Qwen/Qwen3-1.7B" +MODEL_PATH = "Qwen/Qwen3-4B" +PAD_TOKEN_ID = 151643 +EOS_TOKEN_ID = 151645 # --------------------------------------------------------------------------- # SGLang config builder # --------------------------------------------------------------------------- -def _make_sglang_cfg(tp_size): +def _make_sglang_cfg(tp_size, pad_token_id=PAD_TOKEN_ID): return { + "backend": "sglang", + "model_name": MODEL_PATH, + "model_path": MODEL_PATH, + "tokenizer": {"name": MODEL_PATH}, + "dtype": "bfloat16", + "max_new_tokens": 16, + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "stop_token_ids": [EOS_TOKEN_ID], + "stop_strings": None, + "_pad_token_id": pad_token_id, "sglang_cfg": { "model_path": MODEL_PATH, "dtype": "bfloat16", @@ -76,6 +91,7 @@ def _make_sglang_cfg(tp_size): "sglang_router_ip": None, "sglang_router_port": None, }, + "sglang_kwargs": {}, } From f31c49daf73e359ba4800905e97edcd8913ff5ff Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 15:21:41 -0700 Subject: [PATCH 32/36] update --- nemo_rl/models/generation/redesign/sglang_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index 38ded88980..c6dfd2ee2d 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -458,7 +458,7 @@ def _build_sampling_params( def generate( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = 
False ) -> BatchedDataDict[GenerationOutputSpec]: - """Generate a batch of data using vLLM generation. + """Generate a batch of data using Sglang generation. Args: data: BatchedDataDict containing input_ids and input_lengths tensors From c2ef4f2059e175610f7c50c0095ddc594ba90b00 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 15:51:15 -0700 Subject: [PATCH 33/36] update --- .../generation/redesign/test_sglang_generation.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py index 88862f1b08..c3d7a99095 100644 --- a/tests/unit/models/generation/redesign/test_sglang_generation.py +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -221,20 +221,6 @@ def test_generate_output_ids_shape(sglang_gen, tokenizer): assert result["unpadded_sequence_lengths"][0].item() == input_len + gen_len -def test_generate_produces_nonzero_tokens(sglang_gen, tokenizer): - """Generated tokens are non-zero and non-pad.""" - data = _make_input(tokenizer, "What is 2 plus 2?") - result = sglang_gen.generate(data, greedy=True) - - gen_len = result["generation_lengths"][0].item() - assert gen_len > 0, "No tokens generated" - - input_len = data["input_lengths"][0].item() - generated = result["output_ids"][0, input_len:input_len + gen_len] - assert (generated != 0).all(), "Generated tokens contain zeros" - assert (generated != PAD_TOKEN_ID).all(), "Generated tokens contain pad" - - def test_generate_greedy_determinism(sglang_gen, tokenizer): """Same prompt + greedy=True → identical output_ids across two calls.""" data = _make_input(tokenizer, "Once upon a time") From c114f45a4dc03cacba48f2f8ca7366c617732b69 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Tue, 14 Apr 2026 16:22:07 -0700 Subject: [PATCH 34/36] update --- tests/unit/models/generation/redesign/test_sglang_generation.py | 2 -- 1 file changed, 2 deletions(-) diff 
--git a/tests/unit/models/generation/redesign/test_sglang_generation.py b/tests/unit/models/generation/redesign/test_sglang_generation.py index c3d7a99095..ba5e61c5f3 100644 --- a/tests/unit/models/generation/redesign/test_sglang_generation.py +++ b/tests/unit/models/generation/redesign/test_sglang_generation.py @@ -458,7 +458,6 @@ def test_generate_after_memory_cycle_top_level_api(sglang_gen, tokenizer): gen_len_before = r_before["generation_lengths"][0].item() assert gen_len_before > 0, "generate() before memory cycle produced 0 tokens" tokens_before = r_before["output_ids"][0, input_len : input_len + gen_len_before] - assert (tokens_before != 0).all(), "before: generated tokens contain zeros" assert (tokens_before != PAD_TOKEN_ID).all(), "before: generated tokens contain pad" # Full offload + onload cycle using the top-level SGLangGeneration API. @@ -471,7 +470,6 @@ def test_generate_after_memory_cycle_top_level_api(sglang_gen, tokenizer): gen_len_after = r_after["generation_lengths"][0].item() assert gen_len_after > 0, "generate() after memory cycle produced 0 tokens" tokens_after = r_after["output_ids"][0, input_len : input_len + gen_len_after] - assert (tokens_after != 0).all(), "after: generated tokens contain zeros" assert (tokens_after != PAD_TOKEN_ID).all(), "after: generated tokens contain pad" assert gen_len_before == gen_len_after, ( From 5b684ebd2ead0ec5ffec625d59b883bbacc0bc7c Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Thu, 16 Apr 2026 00:36:18 -0700 Subject: [PATCH 35/36] update --- .../configs/grpo_math_1B_redesign_sglang.yaml | 36 +++++++++ ...nstruct-1n4g-fsdp2tp1-redesign-sglang.yaml | 59 ++++++++++++++ .../generation/redesign/sglang_generation.py | 30 +++++-- .../generation/redesign/sglang_worker.py | 1 + .../llm/grpo-math-1b-redesign-1n4g-sglang.sh | 79 +++++++++++++++++++ ...-instruct-1n4g-fsdp2tp1-redesign-sglang.sh | 45 +++++++++++ 6 files changed, 244 insertions(+), 6 deletions(-) create mode 100644 
examples/configs/grpo_math_1B_redesign_sglang.yaml create mode 100644 examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml create mode 100755 tests/test_suites/llm/grpo-math-1b-redesign-1n4g-sglang.sh create mode 100755 tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh diff --git a/examples/configs/grpo_math_1B_redesign_sglang.yaml b/examples/configs/grpo_math_1B_redesign_sglang.yaml new file mode 100644 index 0000000000..eaa025dfd1 --- /dev/null +++ b/examples/configs/grpo_math_1B_redesign_sglang.yaml @@ -0,0 +1,36 @@ +defaults: grpo_math_1B.yaml + +grpo: + val_batch_size: 128 + +policy: + generation: + backend: "sglang" + sglang_cfg: + # SGLang specific configuration + model_path: ${policy.model_name} + dtype: ${policy.precision} + context_length: 512 # Maximum context length + allow_auto_truncate: true + dp_size: 1 + pp_size: 1 + ep_size: 1 + random_seed: 42 + max_running_requests: null + mem_fraction_static: 0.7 + skip_server_warmup: true + # Piecewise CUDA graph currently crashes with "illegal memory access" + # (likely torch 2.10 + sglang incompatibility). Keep disabled until upstream fix. 
+ disable_piecewise_cuda_graph: true + sglang_server: + needs_offload: true + cpu_weight_backup: false + sglang_server_concurrency: 64 + num_gpus: 4 + num_gpus_per_engine: 2 +logger: + wandb_enabled: true + +cluster: + gpus_per_node: 4 + num_nodes: 1 \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml new file mode 100644 index 0000000000..ff3a82e948 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml @@ -0,0 +1,59 @@ +defaults: ../../grpo_math_1B.yaml + +grpo: + max_num_steps: 450 + val_batch_size: 128 + +checkpointing: + checkpoint_dir: results/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang + +policy: + model_name: Qwen/Qwen2.5-Math-1.5B-Instruct + tokenizer: + name: Qwen/Qwen2.5-Math-1.5B-Instruct + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + backend: "sglang" + max_new_tokens: 512 + sglang_cfg: + model_path: ${policy.model_name} + dtype: ${policy.precision} + context_length: 512 + allow_auto_truncate: true + dp_size: 1 + pp_size: 1 + ep_size: 1 + random_seed: 42 + max_running_requests: null + mem_fraction_static: 0.7 + skip_server_warmup: true + # Piecewise CUDA graph currently crashes with "illegal memory access" + # (likely torch 2.10 + sglang incompatibility). Keep disabled until upstream fix. 
+ disable_piecewise_cuda_graph: true + disable_cuda_graph: true + sglang_server: + needs_offload: true + cpu_weight_backup: false + sglang_server_concurrency: 64 + num_gpus: 2 + num_gpus_per_engine: 2 + sglang_router: {} + +data: + max_input_seq_length: 512 + +logger: + log_dir: logs/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang + +cluster: + gpus_per_node: 2 + num_nodes: 1 diff --git a/nemo_rl/models/generation/redesign/sglang_generation.py b/nemo_rl/models/generation/redesign/sglang_generation.py index c6dfd2ee2d..a425689a24 100644 --- a/nemo_rl/models/generation/redesign/sglang_generation.py +++ b/nemo_rl/models/generation/redesign/sglang_generation.py @@ -56,6 +56,7 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan self.cluster = cluster self.cluster_cfg = cluster_cfg self.sglang_cfg = sglang_cfg + self._health_monitor = None pgs = cluster._init_placement_groups( strategy="PACK", @@ -99,11 +100,10 @@ def __init__(self, cluster: RayVirtualCluster, cluster_cfg: ClusterConfig, sglan self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() - monitor = None if sglang_cfg["sglang_cfg"].get("use_fault_tolerance"): monitor = RolloutHealthMonitor(self, sglang_cfg) monitor.start() - self._health_monitor = monitor + self._health_monitor = monitor # ------------------------------------------------------------------ # Engine topology properties (formerly ``ServerGroup``) @@ -796,11 +796,29 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]: def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: """Wake workers up for colocated inference.""" - pass + tags = kwargs.get("tags", None) + if self.needs_offload: + if tags is None: + self.onload_weights() + self.onload_kv() + else: + if "weights" in tags: + self.onload_weights() + if 
"kv_cache" in tags: + self.onload_kv() def finish_generation(self, *args: Any, **kwargs: Any) -> bool: """Sleep workers and reset prefix cache.""" - pass + tags = kwargs.get("tags", None) + if self.needs_offload: + if tags is None: + self.offload_weights() + self.offload_kv() + else: + if "weights" in tags: + self.offload_weights() + if "kv_cache" in tags: + self.offload_kv() def invalidate_kv_cache(self) -> bool: """Invalidate KV cache before weight updates (Megatron-style). @@ -1029,8 +1047,8 @@ def _start_router( ``actor_handle=None`` (we do not own that router and must not terminate it). Otherwise spawn a ``RouterActor`` in sglang env to own the router process. """ - router_cfg = sglang_cfg["sglang_router"] - if router_cfg["sglang_router_ip"] is not None: + router_cfg = sglang_cfg.get("sglang_router") or {} + if router_cfg.get("sglang_router_ip") is not None: return router_cfg["sglang_router_ip"], router_cfg["sglang_router_port"], None router_actor = RouterActor.options( diff --git a/nemo_rl/models/generation/redesign/sglang_worker.py b/nemo_rl/models/generation/redesign/sglang_worker.py index 02aea6b3c7..092ec29a14 100644 --- a/nemo_rl/models/generation/redesign/sglang_worker.py +++ b/nemo_rl/models/generation/redesign/sglang_worker.py @@ -590,6 +590,7 @@ def _compute_server_args( "mem_fraction_static", "allow_auto_truncate", "disable_piecewise_cuda_graph", + "disable_cuda_graph", ]: if key in sglang_cfg["sglang_cfg"]: kwargs[key] = sglang_cfg["sglang_cfg"][key] diff --git a/tests/test_suites/llm/grpo-math-1b-redesign-1n4g-sglang.sh b/tests/test_suites/llm/grpo-math-1b-redesign-1n4g-sglang.sh new file mode 100755 index 0000000000..7ed949b235 --- /dev/null +++ b/tests/test_suites/llm/grpo-math-1b-redesign-1n4g-sglang.sh @@ -0,0 +1,79 @@ +#!/bin/bash +set -eou pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../../..) 
+ +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +CKPT_DIR=$EXP_DIR/ckpts +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log + +# This test targets the redesigned sglang backend config which lives outside +# examples/configs/recipes/llm, so we set CONFIG_PATH explicitly rather than +# sourcing common.env. +CONFIG_PATH=$PROJECT_ROOT/examples/configs/grpo_math_1B_redesign_sglang.yaml +if [[ ! -f $CONFIG_PATH ]]; then + echo "[ERROR] Config file $CONFIG_PATH not found" + exit 1 +fi + +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-4,5,6,7} + +exit_if_max_steps_reached() { + STEPS_SO_FAR=$(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS || echo 0) + if [[ $STEPS_SO_FAR -ge $MAX_STEPS ]]; then + echo "[INFO] Target step $MAX_STEPS reached, skipping run" + exit 0 + fi + echo "[INFO] Steps so far: $STEPS_SO_FAR, running till $MAX_STEPS steps" +} + +if [[ -n "${TEST_DRYRUN:-}" ]]; then + echo "[INFO] TEST_DRYRUN mode: used for testing" + exit +fi + +mkdir -p $EXP_DIR $LOG_DIR $CKPT_DIR + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=450 +MAX_STEPS=450 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=150 # ~13.7s/step without piecewise CUDA graphs +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is 
reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ + 'mean(data["timing/train/total_step_time"], 2) < 25' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh new file mode 100755 index 0000000000..b92891150e --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh @@ -0,0 +1,45 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-4,5,6,7} + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=450 +MAX_STEPS=450 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=150 # ~13.7s/step without piecewise CUDA graphs +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +# Same thresholds as the 1n8g fsdp2tp1-sglang recipe for alignment verification +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' 
$JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ + 'mean(data["timing/train/total_step_time"], 2) < 25' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi From db93f08856fe84fd1e612925b226a0c69c0605d4 Mon Sep 17 00:00:00 2001 From: zhihaow6 Date: Thu, 16 Apr 2026 09:02:33 -0700 Subject: [PATCH 36/36] update --- ...struct-1n2g-fsdp2tp1-redesign-sglang.yaml} | 8 ++-- ...math-1.5b-instruct-1n2g-fsdp2tp1-vllm.yaml | 38 ++++++++++++++++ ...instruct-1n2g-fsdp2tp1-redesign-sglang.sh} | 0 ...5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.sh | 44 +++++++++++++++++++ 4 files changed, 86 insertions(+), 4 deletions(-) rename examples/configs/recipes/llm/{grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml => grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.yaml} (85%) create mode 100644 examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.yaml rename tests/test_suites/llm/{grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh => grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.sh} (100%) create mode 100644 tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.sh diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.yaml similarity index 85% rename from examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml rename to examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.yaml index ff3a82e948..db638cc136 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.yaml @@ -5,7 +5,7 @@ 
grpo: val_batch_size: 128 checkpointing: - checkpoint_dir: results/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang + checkpoint_dir: results/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang policy: model_name: Qwen/Qwen2.5-Math-1.5B-Instruct @@ -37,7 +37,7 @@ policy: disable_cuda_graph: true sglang_server: needs_offload: true - cpu_weight_backup: false + cpu_weight_backup: true sglang_server_concurrency: 64 num_gpus: 2 num_gpus_per_engine: 2 @@ -47,12 +47,12 @@ data: max_input_seq_length: 512 logger: - log_dir: logs/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang + log_dir: logs/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang wandb_enabled: true tensorboard_enabled: true wandb: project: nemo-rl - name: grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang + name: grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang cluster: gpus_per_node: 2 diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.yaml new file mode 100644 index 0000000000..14219f0508 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.yaml @@ -0,0 +1,38 @@ +defaults: ../../grpo_math_1B.yaml + +grpo: + max_num_steps: 450 + val_batch_size: 128 + +checkpointing: + checkpoint_dir: results/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm + +policy: + model_name: Qwen/Qwen2.5-Math-1.5B-Instruct + tokenizer: + name: Qwen/Qwen2.5-Math-1.5B-Instruct + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 + tensor_parallel_size: 2 + +data: + max_input_seq_length: 512 + +logger: + log_dir: logs/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: 
grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm + +cluster: + gpus_per_node: 2 + num_nodes: 1 diff --git a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.sh similarity index 100% rename from tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n4g-fsdp2tp1-redesign-sglang.sh rename to tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-redesign-sglang.sh diff --git a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.sh b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.sh new file mode 100644 index 0000000000..d68fc68d3d --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n2g-fsdp2tp1-vllm.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-6,7} + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=450 +MAX_STEPS=450 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=150 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py 
$JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ + 'mean(data["timing/train/total_step_time"], 2) < 25' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi