From 38912bdd6043cc19d7871b4c4854ccfca10eec04 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 18 Feb 2025 12:57:00 -0600 Subject: [PATCH 01/34] Rollout manager with exception counter & chunking fix; .gitignore update --- .gitignore | 2 + ldp/alg/rollout.py | 53 ++++++++++++++++--- ldp/graph/async_torch.py | 2 +- ldp/nn/__init__.py | 3 ++ ....jinja => llama3_chat_template_test.jinja} | 0 ldp/nn/handlers/chunking.py | 5 +- 6 files changed, 55 insertions(+), 10 deletions(-) rename ldp/nn/chat_templates/{llama3_chat_template_ori.jinja => llama3_chat_template_test.jinja} (100%) diff --git a/.gitignore b/.gitignore index 01a52fbd..57d27823 100644 --- a/.gitignore +++ b/.gitignore @@ -297,3 +297,5 @@ cython_debug/ # Version files made by setuptools_scm **/version.py + +.vscode/ \ No newline at end of file diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py index ff06f266..227982f8 100644 --- a/ldp/alg/rollout.py +++ b/ldp/alg/rollout.py @@ -2,6 +2,7 @@ import itertools import logging import uuid +from collections import Counter from collections.abc import Callable, Iterator, Sequence from contextlib import contextmanager, nullcontext from typing import Any, TypeVar, overload @@ -24,6 +25,7 @@ class CaughtError(Exception): """Base class for reraised exceptions when catching is enabled.""" def __init__(self, original_exc: Exception): + super().__init__(str(original_exc)) self.original_exc = original_exc exc_type = "undefined" @@ -39,12 +41,13 @@ class EnvError(CaughtError): @contextmanager def reraise_exc_as(reraise: type[CaughtError], enabled: bool) -> Iterator[None]: + """Context manager that reraises exceptions as a custom CaughtError type if enabled.""" try: yield except Exception as e: if enabled: - error_details = format_error_details(e) - logger.exception(f"Caught {reraise.exc_type} exception:\n{error_details}") + # Minimal logging instead of spamming. Detailed error stored in the trajectory's metadata. + logger.debug(f"Reraising {reraise.exc_type} exception.") raise reraise(e) from None raise @@ -193,14 +196,50 @@ async def _sample_trajectories_from_envs( max_steps: int | None = None, ) -> list[Trajectory]: self.traj_buffer.clear() + exception_counter = Counter() - traj_ids = [uuid.uuid4().hex for _ in range(len(environments))] - await asyncio.gather( - *( - self._rollout(*args, max_steps=max_steps) - for args in zip(traj_ids, environments, strict=True) + traj_ids = [uuid.uuid4().hex for _ in environments] + + # Create all tasks first + tasks = [ + asyncio.create_task( + self._rollout(traj_id, env, max_steps=max_steps) ) + for traj_id, env in zip(traj_ids, environments, strict=True) + ] + + # Use a single line bar_format to avoid multiline spam. + from tqdm import tqdm + bar_format = ( + "{l_bar}{bar} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]" + " {postfix}" ) + + with tqdm( + total=len(tasks), + desc="Rollouts", + unit="rollout", + bar_format=bar_format, + ) as pbar: + for task in asyncio.as_completed(tasks): + trajectory = await task + pbar.update(1) + # Check if this trajectory ended with an exception + if trajectory.steps: + last_step = trajectory.steps[-1] + if last_step.metadata.get("exception"): + # We'll keep it short but still have something to categorize + exc_str = last_step.metadata["exception"][:500].replace('"', "'") + exception_counter[exc_str] += 1 + num_exceptions = sum(exception_counter.values()) + pbar.set_postfix({"num_exceptions": num_exceptions}) + + # Final summary of exceptions (if any) + if exception_counter: + logger.info("Caught exceptions:") + logger.info("{:<6} {:<50}".format("Count", "Exception")) + for exc, count in exception_counter.items(): + logger.info("{:<6} {:<50}".format(count, exc)) return [self.traj_buffer[traj_id] for traj_id in traj_ids] async def _rollout( diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py index 55ef0e9b..c614ab0b 100644 --- a/ldp/graph/async_torch.py +++ b/ldp/graph/async_torch.py @@ -127,7 +127,7 @@ async def _maybe_process_batch(self) -> None: if ( len(self._work_buffer) >= self.batch_size - or now - self._work_buffer[0][0] > self.timeout + or (now - self._work_buffer[0][0] > self.timeout) and len(self._work_buffer) > 0 ): # if we're over batch size or have at least one input waiting for # more than timeout, pull out a batch to run diff --git a/ldp/nn/__init__.py b/ldp/nn/__init__.py index 46e6dcd0..73b300bf 100644 --- a/ldp/nn/__init__.py +++ b/ldp/nn/__init__.py @@ -1,6 +1,7 @@ from .agent.simple_local_agent import AgentLMConfig, SimpleLocalLLMAgent from .graph.llm_call_op import LocalLLMCallOp from .handlers.chunking import TensorChunker +from .handlers.module_handler import AsyncModuleHandler, ModuleExecutionInterface from .handlers.transformer_handler import ( AsyncTransformer, AsyncTransformerInterface, @@ -20,12 +21,14 @@ __all__ = [ "AgentLMConfig", + "AsyncModuleHandler", "AsyncTransformer", "AsyncTransformerInterface", "ExecutionMode", "LMConfig", "LMType", "LocalLLMCallOp", + "ModuleExecutionInterface", "ParallelAsyncTransformer", "ParallelModeConfig", "ParallelTransformerHandler", diff --git a/ldp/nn/chat_templates/llama3_chat_template_ori.jinja b/ldp/nn/chat_templates/llama3_chat_template_test.jinja similarity index 100% rename from ldp/nn/chat_templates/llama3_chat_template_ori.jinja rename to ldp/nn/chat_templates/llama3_chat_template_test.jinja diff --git a/ldp/nn/handlers/chunking.py b/ldp/nn/handlers/chunking.py index 38fdbe96..65771b5a 100644 --- a/ldp/nn/handlers/chunking.py +++ b/ldp/nn/handlers/chunking.py @@ -159,8 +159,9 @@ def _split_value(self, value): for i in range(self.num_chunks): if i >= len(chunks): # Chunk 0 will always exist, and we need only a batch of one ([:1]) - # to activate the model - chunks.append(torch.full_like(chunks[0][:1], self.dummy_value)) + # to activate the model. + # We use real data to avoid errors in the model expecting certain token structure. + chunks.append(chunks[0][:1]) dummy_chunk_flags.append(True) else: dummy_chunk_flags.append(False) From 13c962224ad7147310b718d912503aebd9ea4290 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 18 Feb 2025 14:13:45 -0600 Subject: [PATCH 02/34] nits --- .gitignore | 2 +- ldp/alg/rollout.py | 18 ++++++++---------- ldp/graph/async_torch.py | 12 +++++++++--- ldp/nn/handlers/chunking.py | 2 +- ldp/nn/handlers/transformer_handler.py | 15 +++++++++++++++ 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 57d27823..5293aadb 100644 --- a/.gitignore +++ b/.gitignore @@ -298,4 +298,4 @@ cython_debug/ # Version files made by setuptools_scm **/version.py -.vscode/ \ No newline at end of file +.vscode/ diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py index 227982f8..a0c2a883 100644 --- a/ldp/alg/rollout.py +++ b/ldp/alg/rollout.py @@ -8,10 +8,10 @@ from typing import Any, TypeVar, overload from aviary.core import Environment, Message +from tqdm import tqdm from ldp.agent import Agent from ldp.data_structures import Trajectory, Transition -from ldp.utils import format_error_details from .callbacks import Callback @@ -196,20 +196,16 @@ async def _sample_trajectories_from_envs( max_steps: int | None = None, ) -> list[Trajectory]: self.traj_buffer.clear() - exception_counter = Counter() + exception_counter: Counter = Counter() traj_ids = [uuid.uuid4().hex for _ in environments] # Create all tasks first tasks = [ - asyncio.create_task( - self._rollout(traj_id, env, max_steps=max_steps) - ) + asyncio.create_task(self._rollout(traj_id, env, max_steps=max_steps)) for traj_id, env in zip(traj_ids, environments, strict=True) ] - # Use a single line bar_format to avoid multiline spam. - from tqdm import tqdm bar_format = ( "{l_bar}{bar} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]" " {postfix}" @@ -229,7 +225,8 @@ async def _sample_trajectories_from_envs( last_step = trajectory.steps[-1] if last_step.metadata.get("exception"): # We'll keep it short but still have something to categorize - exc_str = last_step.metadata["exception"][:500].replace('"', "'") + exc_str: str = last_step.metadata["exception"][:500] + exc_str = exc_str.replace('"', "'") exception_counter[exc_str] += 1 num_exceptions = sum(exception_counter.values()) pbar.set_postfix({"num_exceptions": num_exceptions}) @@ -237,9 +234,10 @@ async def _sample_trajectories_from_envs( # Final summary of exceptions (if any) if exception_counter: logger.info("Caught exceptions:") - logger.info("{:<6} {:<50}".format("Count", "Exception")) + logger.info("%-6s %-50s", "Count", "Exception") for exc, count in exception_counter.items(): - logger.info("{:<6} {:<50}".format(count, exc)) + logger.info("%-6d %-50s", count, exc) + return [self.traj_buffer[traj_id] for traj_id in traj_ids] async def _rollout( diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py index c614ab0b..d9adea65 100644 --- a/ldp/graph/async_torch.py +++ b/ldp/graph/async_torch.py @@ -120,14 +120,20 @@ async def _maybe_process_batch(self) -> None: If neither condition is met, do nothing. """ + # Technically should not happen, but if a coroutine crashes, it could release + # self._lock before placing results in _results_buffer and additional process + # coming inside will crash. + if not self._work_buffer: + return + now = time.time() # sort by oldest requests first self._work_buffer.sort(key=operator.itemgetter(0)) - if ( - len(self._work_buffer) >= self.batch_size - or (now - self._work_buffer[0][0] > self.timeout) and len(self._work_buffer) > 0 + if len(self._work_buffer) >= self.batch_size or ( + (now - self._work_buffer[0][0] > self.timeout) + and len(self._work_buffer) > 0 ): # if we're over batch size or have at least one input waiting for # more than timeout, pull out a batch to run diff --git a/ldp/nn/handlers/chunking.py b/ldp/nn/handlers/chunking.py index 65771b5a..fe34da68 100644 --- a/ldp/nn/handlers/chunking.py +++ b/ldp/nn/handlers/chunking.py @@ -159,7 +159,7 @@ def _split_value(self, value): for i in range(self.num_chunks): if i >= len(chunks): # Chunk 0 will always exist, and we need only a batch of one ([:1]) - # to activate the model. + # to activate the model. # We use real data to avoid errors in the model expecting certain token structure. chunks.append(chunks[0][:1]) dummy_chunk_flags.append(True) diff --git a/ldp/nn/handlers/transformer_handler.py b/ldp/nn/handlers/transformer_handler.py index 8c53ede5..0bb84378 100644 --- a/ldp/nn/handlers/transformer_handler.py +++ b/ldp/nn/handlers/transformer_handler.py @@ -1,5 +1,6 @@ from __future__ import annotations +import atexit import logging import os import socket @@ -193,6 +194,14 @@ async def __call__( # type: ignore[override] @staticmethod def model_generate(model: PreTrainedModel, *args, **kwargs): """A method that can be used as module_call_fn to sample from an LLM.""" + if dist.get_world_size() > 1: + synced_gpus = kwargs.pop("synced_gpus", None) + if synced_gpus is None: + logger.debug("synced_gpus not defined, defaulting to True.") + elif not synced_gpus: + raise ValueError("synced_gpus must be True when using FSDP.") + kwargs["synced_gpus"] = True + # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. with FullyShardedDataParallel.summon_full_params(model, recurse=False): @@ -463,6 +472,8 @@ def __init__(self, config: TransformerHandlerConfig): self._initialized = True + atexit.register(self.teardown) + # don't call AsyncTorchModule.__init__ because we don't need to set up module[_call_fn] AsyncBufferedWorker.__init__( self, @@ -484,6 +495,10 @@ def _init_local_cluster( # lazy import since dask-cuda only works on Linux machines from dask_cuda import LocalCUDACluster + # This uses NVIDIA's NVML layer instead of native CUDA, which is more robust in GPU detection + # post initialization. This prevents issues with forked processes wrongly detecting the + # default GPU as cuda:0 + os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" self.cluster = LocalCUDACluster( n_workers=parallel_mode_config.num_workers, threads_per_worker=parallel_mode_config.num_cpus_per_worker, From a7c1bf2736790536f38a42a1cbdd31d1dde24b2e Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 18 Feb 2025 14:15:27 -0600 Subject: [PATCH 03/34] nits --- ldp/alg/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py index a0c2a883..975de0db 100644 --- a/ldp/alg/rollout.py +++ b/ldp/alg/rollout.py @@ -225,7 +225,7 @@ async def _sample_trajectories_from_envs( last_step = trajectory.steps[-1] if last_step.metadata.get("exception"): # We'll keep it short but still have something to categorize - exc_str: str = last_step.metadata["exception"][:500] + exc_str: str = str(last_step.metadata["exception"])[:500] exc_str = exc_str.replace('"', "'") exception_counter[exc_str] += 1 num_exceptions = sum(exception_counter.values()) From ae71669f05f5cfb0f6c9e3b5920cfd05e8f1fc03 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Sun, 23 Feb 2025 08:47:29 -0600 Subject: [PATCH 04/34] nits --- ldp/alg/rollout.py | 1 + ldp/graph/async_torch.py | 1 - ldp/nn/agent/simple_local_agent.py | 20 ++++++++++++++++++++ ldp/nn/handlers/transformer_handler.py | 9 ++++++++- 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py index 975de0db..31d9f9e0 100644 --- a/ldp/alg/rollout.py +++ b/ldp/alg/rollout.py @@ -245,6 +245,7 @@ async def _rollout( traj_id: str, env: Environment, max_steps: int | None, + max_tokens: int | None = None, # <-- new argument ) -> Trajectory: trajectory = Trajectory(traj_id=traj_id) diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py index d9adea65..f0b2f2df 100644 --- a/ldp/graph/async_torch.py +++ b/ldp/graph/async_torch.py @@ -133,7 +133,6 @@ async def _maybe_process_batch(self) -> None: if len(self._work_buffer) >= self.batch_size or ( (now - self._work_buffer[0][0] > self.timeout) - and len(self._work_buffer) > 0 ): # if we're over batch size or have at least one input waiting for # more than timeout, pull out a batch to run diff --git a/ldp/nn/agent/simple_local_agent.py b/ldp/nn/agent/simple_local_agent.py index 9954b03e..f59fbf3a 100644 --- a/ldp/nn/agent/simple_local_agent.py +++ b/ldp/nn/agent/simple_local_agent.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist +from litellm import token_counter from aviary.core import Message, Tool, ToolRequestMessage from pydantic import Field, field_validator @@ -41,6 +42,11 @@ class AgentLMConfig(_LMConfig): "are better defaults than HF's.", validate_default=True, ) + + max_traj_token_count: int | None = Field( + default=None, + description="If set, raise an error if the total tokens in the trajectory exceed this value." + ) @field_validator("llm_call_kwargs") @classmethod @@ -110,6 +116,20 @@ async def get_asv( # Update state messages with result and return the new state next_state.messages = [*next_state.messages, result.value] + + + import ipdb; ipdb.set_trace() + if self.llm_model.max_traj_token_count is not None: + total_tokens = token_counter( + model=self.llm_model.llm_for_sft, # or any field referencing the model name + messages=next_state.messages, + tools=next_state.tools, + ) + if total_tokens > self.llm_model.max_traj_token_count: + raise ValueError( + f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" + ) + return cast(OpResult[ToolRequestMessage], result), next_state, 0.0 # TODO: maybe remove these recomputation methods. I added them to debug some things. But idk, diff --git a/ldp/nn/handlers/transformer_handler.py b/ldp/nn/handlers/transformer_handler.py index 0bb84378..47f5253f 100644 --- a/ldp/nn/handlers/transformer_handler.py +++ b/ldp/nn/handlers/transformer_handler.py @@ -47,6 +47,7 @@ else: from typing_extensions import overload # noqa: UP035 +logger = logging.getLogger(__name__) config.set({ # We have no use for rebooting workers in aviary for now, and rebooting workers @@ -60,7 +61,10 @@ "distributed.comm.timeouts.tcp": "300s", }) -logger = logging.getLogger(__name__) +compression = os.getenv("USE_DASK_COMPRESSION") +if compression is not None: + config.set({"distributed.comm.compression": compression}) + logger.info(f"Setting Dask compression to {compression}") TReturn = TypeVar("TReturn") TParams = ParamSpec("TParams") @@ -201,6 +205,9 @@ def model_generate(model: PreTrainedModel, *args, **kwargs): elif not synced_gpus: raise ValueError("synced_gpus must be True when using FSDP.") kwargs["synced_gpus"] = True + if os.getenv("USE_DASK_BARRIER"): + logger.info("Waiting for all workers to reach this point.") + dist.barrier() # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. From 39b0fc05b83c5bcb1bace29a37e76d0dd5d59c12 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Sun, 23 Feb 2025 09:17:04 -0600 Subject: [PATCH 05/34] nits --- ldp/alg/rollout.py | 10 ++-------- ldp/nn/handlers/chunking.py | 3 ++- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/ldp/alg/rollout.py b/ldp/alg/rollout.py index 31d9f9e0..dc9acb5f 100644 --- a/ldp/alg/rollout.py +++ b/ldp/alg/rollout.py @@ -206,16 +206,11 @@ async def _sample_trajectories_from_envs( for traj_id, env in zip(traj_ids, environments, strict=True) ] - bar_format = ( - "{l_bar}{bar} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]" - " {postfix}" - ) - with tqdm( total=len(tasks), desc="Rollouts", unit="rollout", - bar_format=bar_format, + ncols=0, ) as pbar: for task in asyncio.as_completed(tasks): trajectory = await task @@ -225,8 +220,7 @@ async def _sample_trajectories_from_envs( last_step = trajectory.steps[-1] if last_step.metadata.get("exception"): # We'll keep it short but still have something to categorize - exc_str: str = str(last_step.metadata["exception"])[:500] - exc_str = exc_str.replace('"', "'") + exc_str: str = str(last_step.metadata["exception"])[:500].replace('"', "'") exception_counter[exc_str] += 1 num_exceptions = sum(exception_counter.values()) pbar.set_postfix({"num_exceptions": num_exceptions}) diff --git a/ldp/nn/handlers/chunking.py b/ldp/nn/handlers/chunking.py index fe34da68..147b6dd8 100644 --- a/ldp/nn/handlers/chunking.py +++ b/ldp/nn/handlers/chunking.py @@ -160,7 +160,8 @@ def _split_value(self, value): if i >= len(chunks): # Chunk 0 will always exist, and we need only a batch of one ([:1]) # to activate the model. - # We use real data to avoid errors in the model expecting certain token structure. + # We use the first element of the existing chunks as real data to avoid + # errors in the model that may expect a specific token structure. chunks.append(chunks[0][:1]) dummy_chunk_flags.append(True) else: From da92fbf9b87f89b09b2a270dfec7d0a463602e3e Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Sun, 23 Feb 2025 09:28:46 -0600 Subject: [PATCH 06/34] nits --- src/ldp/graph/async_torch.py | 7 +- .../llama3_chat_template_test.jinja | 113 ------------------ 2 files changed, 4 insertions(+), 116 deletions(-) delete mode 100644 src/ldp/nn/chat_templates/llama3_chat_template_test.jinja diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index f0b2f2df..f6597bb9 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -122,7 +122,7 @@ async def _maybe_process_batch(self) -> None: """ # Technically should not happen, but if a coroutine crashes, it could release # self._lock before placing results in _results_buffer and additional process - # coming inside will crash. + # coming inside this func will crash as self._work_buffer will be empty. if not self._work_buffer: return @@ -131,8 +131,9 @@ async def _maybe_process_batch(self) -> None: # sort by oldest requests first self._work_buffer.sort(key=operator.itemgetter(0)) - if len(self._work_buffer) >= self.batch_size or ( - (now - self._work_buffer[0][0] > self.timeout) + if ( + len(self._work_buffer) >= self.batch_size + or now - self._work_buffer[0][0] > self.timeout ): # if we're over batch size or have at least one input waiting for # more than timeout, pull out a batch to run diff --git a/src/ldp/nn/chat_templates/llama3_chat_template_test.jinja b/src/ldp/nn/chat_templates/llama3_chat_template_test.jinja deleted file mode 100644 index e2807a55..00000000 --- a/src/ldp/nn/chat_templates/llama3_chat_template_test.jinja +++ /dev/null @@ -1,113 +0,0 @@ -{{- bos_token }} -{%- if custom_tools is defined %} - {%- set tools = custom_tools %} -{%- endif %} -{%- if not tools_in_user_message is defined %} - {%- set tools_in_user_message = true %} -{%- endif %} -{%- if not date_string is defined %} - {%- set date_string = "26 Jul 2024" %} -{%- endif %} -{%- if not tools is defined %} - {%- set tools = none %} -{%- endif %} - -{#- This block extracts the system message, so we can slot it into the right place. #} -{%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} - {%- set messages = messages[1:] %} -{%- else %} - {%- set system_message = "" %} -{%- endif %} - -{#- System message + builtin tools #} -{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} -{%- if builtin_tools is defined or tools is not none %} - {{- "Environment: ipython\n" }} -{%- endif %} -{%- if builtin_tools is defined %} - {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} -{%- endif %} -{{- "Cutting Knowledge Date: December 2023\n" }} -{{- "Today Date: " + date_string + "\n\n" }} -{%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value, "thought": optional succeint reasoning processes leading to calling this tool, }.' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} -{%- endif %} -{{- system_message }} -{{- "<|eot_id|>" }} - -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and not tools is none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} -{%- endif %} - {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} - {{- "Given the following functions, please respond with a JSON for a function call " }} - {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value, "thought": succeint reasoning processes leading to calling this tool, }.' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- first_user_message + "<|eot_id|>"}} -{%- endif %} - -{%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {%- if message['role'] == 'assistant' %} - {% generation %}{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}{% endgeneration %} - {%- else %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }} - {%- endif %} - {%- elif 'tool_calls' in message %} - {%- if not message.tool_calls|length == 1 %} - {{- raise_exception("This model only supports single tool-calls at once!") }} - {%- endif %} - {%- set tool_call = message.tool_calls[0].function %} - {% generation %}{%- if builtin_tools is defined and tool_call.name in builtin_tools %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} - {{- "<|python_tag|>" + tool_call.name + ".call(" }} - {%- for arg_name, arg_val in tool_call.arguments | items %} - {{- arg_name + '="' + arg_val + '"' }} - {%- if not loop.last %} - {{- ", " }} - {%- endif %} - {%- endfor %} - {{- ")" }} - {%- else %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} - {{- '{"name": "' + tool_call.name + '", ' }} - {{- '"parameters": ' }} - {{- tool_call.arguments | tojson }} - {{- "}" }} - {%- endif %} - {%- if builtin_tools is defined %} - {#- This means we're in ipython mode #} - {{- "<|eom_id|>" }} - {%- else %} - {{- "<|eot_id|>" }} - {%- endif %}{% endgeneration %} - {%- elif message.role == "tool" or message.role == "ipython" %} - {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is mapping or message.content is iterable %} - {{- message.content | tojson }} - {%- else %} - {{- message.content }} - {%- endif %} - {{- "<|eot_id|>" }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {% generation %}{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endgeneration %} -{%- endif %} From 2ba58981a2171a438e6511d690b0f146a53c1574 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 24 Feb 2025 06:43:51 -0600 Subject: [PATCH 07/34] nits --- src/ldp/alg/rollout.py | 4 ++- src/ldp/nn/agent/simple_local_agent.py | 31 +++++++++++++++------- src/ldp/nn/handlers/transformer_handler.py | 6 +---- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index dc9acb5f..2b1a1b1d 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -220,7 +220,9 @@ async def _sample_trajectories_from_envs( last_step = trajectory.steps[-1] if last_step.metadata.get("exception"): # We'll keep it short but still have something to categorize - exc_str: str = str(last_step.metadata["exception"])[:500].replace('"', "'") + exc_str: str = str(last_step.metadata["exception"])[ + :500 + ].replace('"', "'") exception_counter[exc_str] += 1 num_exceptions = sum(exception_counter.values()) pbar.set_postfix({"num_exceptions": num_exceptions}) diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 011a4d48..f30da850 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -2,8 +2,8 @@ import torch import torch.distributed as dist -from litellm import token_counter from aviary.core import Message, Tool, ToolRequestMessage +from litellm import token_counter from pydantic import Field, field_validator from ldp.agent import Agent, SimpleAgentState @@ -41,10 +41,9 @@ class AgentLMConfig(_LMConfig): "are better defaults than HF's.", validate_default=True, ) - max_traj_token_count: int | None = Field( default=None, - description="If set, raise an error if the total tokens in the trajectory exceed this value." + description="If set, raise an error if the total tokens in the trajectory exceed this value.", ) @field_validator("llm_call_kwargs") @@ -115,20 +114,32 @@ async def get_asv( # Update state messages with result and return the new state next_state.messages = [*next_state.messages, result.value] - - - import ipdb; ipdb.set_trace() + if self.llm_model.max_traj_token_count is not None: + messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer( + next_state.messages + ) + tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer( + next_state.tools + ) total_tokens = token_counter( - model=self.llm_model.llm_for_sft, # or any field referencing the model name - messages=next_state.messages, - tools=next_state.tools, + model=self.llm_model.model, # or any field referencing the model name + messages=messages_for_tokenizer, + tools=tools_for_tokenizer, + ) + # TODO remove + print( + "The traj size is %d tokens, with a limit of %d tokens" + % (total_tokens, self.llm_model.max_traj_token_count) ) if total_tokens > self.llm_model.max_traj_token_count: + import ipdb + + ipdb.set_trace() # TODO remove raise ValueError( f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" ) - + return cast(OpResult[ToolRequestMessage], result), next_state, 0.0 # TODO: maybe remove these recomputation methods. I added them to debug some things. But idk, diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index fb3a5d46..6fcc07e3 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -60,11 +60,6 @@ "distributed.comm.timeouts.tcp": "300s", }) -compression = os.getenv("USE_DASK_COMPRESSION") -if compression is not None: - config.set({"distributed.comm.compression": compression}) - logger.info(f"Setting Dask compression to {compression}") - TReturn = TypeVar("TReturn") TParams = ParamSpec("TParams") @@ -204,6 +199,7 @@ def model_generate(model: PreTrainedModel, *args, **kwargs): elif not synced_gpus: raise ValueError("synced_gpus must be True when using FSDP.") kwargs["synced_gpus"] = True + # TODO remove if os.getenv("USE_DASK_BARRIER"): logger.info("Waiting for all workers to reach this point.") dist.barrier() From 62977864e4755dac5587d19d8ee290d18c863d0f Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 3 Mar 2025 14:01:04 -0600 Subject: [PATCH 08/34] nits --- src/ldp/nn/agent/simple_local_agent.py | 47 +++++++++++----------- src/ldp/nn/handlers/transformer_handler.py | 34 +++++++++++----- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index f30da850..50ee80a5 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -1,3 +1,4 @@ +import logging from typing import cast import torch @@ -18,6 +19,8 @@ ) from ldp.nn.lm_config import LMConfig as _LMConfig +logger = logging.getLogger(__name__) + class AgentLMConfig(_LMConfig): """Adds some additional configuration options for running an LM in an Op.""" @@ -94,6 +97,8 @@ async def get_asv( else next_state.messages ) + self._validate_token_count(messages, next_state.tools) + # Execute the LLM operation call result = cast( OpResult[Message | ToolRequestMessage], @@ -114,33 +119,27 @@ async def get_asv( # Update state messages with result and return the new state next_state.messages = [*next_state.messages, result.value] + self._validate_token_count(next_state.messages, next_state.tools) - if self.llm_model.max_traj_token_count is not None: - messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer( - next_state.messages - ) - tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer( - next_state.tools - ) - total_tokens = token_counter( - model=self.llm_model.model, # or any field referencing the model name - messages=messages_for_tokenizer, - tools=tools_for_tokenizer, + return cast(OpResult[ToolRequestMessage], result), next_state, 0.0 + + def _validate_token_count(self, messages: list[Message], tools: list[Tool]): + if self.llm_model.max_traj_token_count is None: + return + messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages) + tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer(tools) + total_tokens = token_counter( + model=self.llm_model.model, + messages=messages_for_tokenizer, + tools=tools_for_tokenizer, + ) + if total_tokens > self.llm_model.max_traj_token_count: + logger.error( + f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" ) - # TODO remove - print( - "The traj size is %d tokens, with a limit of %d tokens" - % (total_tokens, self.llm_model.max_traj_token_count) + raise ValueError( + f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" ) - if total_tokens > self.llm_model.max_traj_token_count: - import ipdb - - ipdb.set_trace() # TODO remove - raise ValueError( - f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" - ) - - return cast(OpResult[ToolRequestMessage], result), next_state, 0.0 # TODO: maybe remove these recomputation methods. I added them to debug some things. But idk, # maybe they'll come in handy later. diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 6fcc07e3..1edcb54c 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -17,7 +17,7 @@ import torch.distributed as dist import tree from dask import config -from dask.distributed import Client +from dask.distributed import Client, as_completed, wait from pydantic import BaseModel, ConfigDict, Field, field_validator from torch import nn from torch.cuda import nccl @@ -196,13 +196,10 @@ def model_generate(model: PreTrainedModel, *args, **kwargs): synced_gpus = kwargs.pop("synced_gpus", None) if synced_gpus is None: logger.debug("synced_gpus not defined, defaulting to True.") + kwargs["synced_gpus"] = True elif not synced_gpus: raise ValueError("synced_gpus must be True when using FSDP.") - kwargs["synced_gpus"] = True - # TODO remove - if os.getenv("USE_DASK_BARRIER"): - logger.info("Waiting for all workers to reach this point.") - dist.barrier() + # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. @@ -585,7 +582,7 @@ def get_cuda_visible_devices() -> int | None: futures.append(future_op) worker_ids.append(worker_id) - self.handlers = self.client.gather(futures) + self.handlers = self.client_gather(futures) self.worker_ids = worker_ids async def __call__( @@ -656,7 +653,7 @@ def _submit_and_gather( self.handlers, self.worker_ids, split_args, split_kwargs, strict=True ) ] - results = self.client.gather(futures) + results = self.client_gather(futures) results = cast(list[TReturn], [res.result().result() for res in results]) if split_data: @@ -767,13 +764,32 @@ def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: def teardown(self) -> None: if self._initialized: - self.client.close() + self.client.shutdown() self.cluster.close() self._initialized = False def __del__(self) -> None: self.teardown() + def client_gather(self, futures): + """Gather results from futures, propagating exceptions as they arrive. + + Unlike client.gather() which waits for all futures to complete before raising + any exceptions, this method processes futures as they complete and raises + exceptions immediately. This is crucial when using FSDP where workers may + be stuck waiting for each other where one worker crashes, causing long hangs. + """ + # Initialize a list to hold results + results = [None] * len(futures) + for completed_future, result in as_completed( + futures, with_results=True, raise_errors=True + ): + # Find the index of the completed future + index = futures.index(completed_future) + # Store the result directly from as_completed + results[index] = result + return results + # Helpers From e4ccb45fb2f0dda7c3ebad7e94de3c762e4dae17 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 4 Mar 2025 06:14:58 -0600 Subject: [PATCH 09/34] nits --- src/ldp/alg/rollout.py | 71 ++++++++++++++++++++++++--------- src/ldp/graph/async_torch.py | 6 --- src/ldp/nn/handlers/chunking.py | 3 +- tests/test_nn_models.py | 17 ++++---- 4 files changed, 60 insertions(+), 37 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 2b1a1b1d..210437de 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -79,6 +79,8 @@ async def sample_trajectories( # noqa: D418 environment_factory: Callable[[], TEnv], batch_size: int = 1, max_steps: int | None = None, + *, + log_exceptions_immediately: bool = False, ) -> list[tuple[Trajectory, TEnv]]: """Run rollouts in parallel, using a factory to construct environments. @@ -92,6 +94,8 @@ async def sample_trajectories( # noqa: D418 an environment instance batch_size (int, optional): Defaults to 1. max_steps (int | None, optional): Max steps per rollout. Defaults to None (see above). + log_exceptions_immediately (bool, optional): Whether to log exceptions as they occur + or only at the end of the rollouts. Defaults to False. Returns: list[tuple[Trajectory, Environment]]: A list of (trajectory, environment) tuples: one per rollout. @@ -102,6 +106,8 @@ async def sample_trajectories( # noqa: D418 self, environments: Sequence[Environment], max_steps: int | None = None, + *, + log_exceptions_immediately: bool = False, ) -> list[Trajectory]: """Run rollouts in parallel on a list of provided environments. @@ -109,26 +115,35 @@ async def sample_trajectories( # noqa: D418 environments: A list of environments to run rollouts on. max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run until environment returns done. + log_exceptions_immediately (bool, optional): Whether to log exceptions as they occur + or only at the end of the rollouts. Defaults to False. """ - async def sample_trajectories(self, **kwargs): - if "environment_factory" in kwargs: - assert "environments" not in kwargs, ( - "Cannot use environment_factory with environments" - ) - + async def sample_trajectories( + self, + environment_factory: Callable[[], Environment] | None = None, + environments: Sequence[Environment] | None = None, + batch_size: int = 1, + max_steps: int | None = None, + *, + log_exceptions_immediately: bool = False, + ) -> list[tuple[Trajectory, Environment]] | list[Trajectory]: + """Sample trajectories from environments, either via factory or pre-created.""" + if environment_factory is not None: + assert environments is None, "Cannot use environment_factory with environments" return await self._sample_trajectories_from_env_factory( - kwargs["environment_factory"], - kwargs.get("batch_size", 1), - kwargs.get("max_steps"), + environment_factory, + batch_size, + max_steps, + log_exceptions_immediately=log_exceptions_immediately, ) - if "environments" in kwargs: - assert "environment_factory" not in kwargs, ( - "Cannot use environments with environment_factory" - ) + if environments is not None: + assert environment_factory is None, "Cannot use environments with environment_factory" return await self._sample_trajectories_from_envs( - kwargs["environments"], kwargs.get("max_steps") + environments, + max_steps, + log_exceptions_immediately=log_exceptions_immediately, ) raise TypeError( @@ -141,6 +156,8 @@ async def _sample_trajectories_from_env_factory( environment_factory: Callable[[], Environment], batch_size: int = 1, max_steps: int | None = None, + *, + log_exceptions_immediately: bool = False, ) -> list[tuple[Trajectory, Environment]]: self.traj_buffer.clear() @@ -156,6 +173,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): traj_id=uuid.uuid4().hex, env=environment_factory(), max_steps=max_steps, + log_exceptions_immediately=log_exceptions_immediately, ) ) for idx in range(batch_size) @@ -182,6 +200,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): traj_id=uuid.uuid4().hex, env=environment_factory(), max_steps=remaining_steps, + log_exceptions_immediately=log_exceptions_immediately, ) ) new_tasks.append(new_task) @@ -194,6 +213,8 @@ async def _sample_trajectories_from_envs( self, environments: Sequence[Environment], max_steps: int | None = None, + *, + log_exceptions_immediately: bool = False, ) -> list[Trajectory]: self.traj_buffer.clear() exception_counter: Counter = Counter() @@ -202,7 +223,14 @@ async def _sample_trajectories_from_envs( # Create all tasks first tasks = [ - asyncio.create_task(self._rollout(traj_id, env, max_steps=max_steps)) + asyncio.create_task( + self._rollout( + traj_id, + env, + max_steps=max_steps, + log_exceptions_immediately=log_exceptions_immediately + ) + ) for traj_id, env in zip(traj_ids, environments, strict=True) ] @@ -229,10 +257,12 @@ async def _sample_trajectories_from_envs( # Final summary of exceptions (if any) if exception_counter: - logger.info("Caught exceptions:") - logger.info("%-6s %-50s", "Count", "Exception") - for exc, count in exception_counter.items(): - logger.info("%-6d %-50s", count, exc) + summary = ["Caught exceptions:", "Count Exception"] + summary.extend( + f"{count:<6d} {exc:<50s}" + for exc, count in exception_counter.items() + ) + logger.info("\n".join(summary)) return [self.traj_buffer[traj_id] for traj_id in traj_ids] @@ -294,6 +324,9 @@ async def store_step(step: Transition): except CaughtError as e: # NOTE: This trajectory should not be used for regular training. # We save the last transition here for debugging, etc. + if log_exceptions_immediately: + logger.exception(f"Exception in rollout {traj_id}: {e.original_exc}") + await store_step( Transition( timestep=len(trajectory.steps), diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index f6597bb9..55ef0e9b 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -120,12 +120,6 @@ async def _maybe_process_batch(self) -> None: If neither condition is met, do nothing. """ - # Technically should not happen, but if a coroutine crashes, it could release - # self._lock before placing results in _results_buffer and additional process - # coming inside this func will crash as self._work_buffer will be empty. - if not self._work_buffer: - return - now = time.time() # sort by oldest requests first diff --git a/src/ldp/nn/handlers/chunking.py b/src/ldp/nn/handlers/chunking.py index 147b6dd8..9369d6d3 100644 --- a/src/ldp/nn/handlers/chunking.py +++ b/src/ldp/nn/handlers/chunking.py @@ -9,9 +9,8 @@ class TensorChunker: """Splits tensors into chunks and adds dummy chunks as needed for parallel processing frameworks like FSDP.""" - def __init__(self, num_chunks: int, dummy_value: int = 0): + def __init__(self, num_chunks: int): self.num_chunks = num_chunks - self.dummy_value = dummy_value def chunkify(self, *args, **kwargs) -> tuple[list[tuple], list[dict], list[bool]]: """Splits the args into self.num_chunks chunks, adding dummy chunks as needed. diff --git a/tests/test_nn_models.py b/tests/test_nn_models.py index aa717b48..e872d9f1 100644 --- a/tests/test_nn_models.py +++ b/tests/test_nn_models.py @@ -27,11 +27,10 @@ class TestTensorChunker: def test_chunkify_add_dummy_chunks(self): batch_size = 3 num_chunks = 5 - dummy_value = 0 sample_tensor = torch.arange(1, batch_size * 10 + 1).reshape(batch_size, 10) - chunker = ldp.nn.TensorChunker(num_chunks=num_chunks, dummy_value=dummy_value) + chunker = ldp.nn.TensorChunker(num_chunks=num_chunks) split_args, split_kwargs, dummy_chunk_flags = chunker.chunkify(sample_tensor) assert len(split_args) == num_chunks @@ -41,20 +40,19 @@ def test_chunkify_add_dummy_chunks(self): assert torch.equal(split_args[1][0], sample_tensor[1:2]) assert torch.equal(split_args[2][0], sample_tensor[2:3]) assert torch.equal( - split_args[3][0], torch.full_like(sample_tensor[:1], dummy_value) + split_args[3][0], sample_tensor[:1] ) assert torch.equal( - split_args[4][0], torch.full_like(sample_tensor[:1], dummy_value) + split_args[4][0], sample_tensor[:1] ) def test_chunkify_no_dummy_chunks(self): batch_size = 9 num_chunks = 5 - dummy_value = 0 sample_tensor = torch.arange(1, batch_size * 10 + 1).reshape(batch_size, 10) - chunker = ldp.nn.TensorChunker(num_chunks=num_chunks, dummy_value=dummy_value) + chunker = ldp.nn.TensorChunker(num_chunks=num_chunks) split_args, split_kwargs, dummy_chunk_flags = chunker.chunkify(sample_tensor) assert len(split_args) == num_chunks @@ -69,7 +67,6 @@ def test_chunkify_no_dummy_chunks(self): def test_chunkify_with_args_and_kwargs(self): batch_size = 2 num_chunks = 3 - dummy_value = 0 sample_tensor = torch.arange(1, batch_size * 10 + 1).reshape(batch_size, 10) sample_tensor_kwarg = torch.arange(1, batch_size * 5 + 1).reshape(batch_size, 5) @@ -78,7 +75,7 @@ def test_chunkify_with_args_and_kwargs(self): "key2": "Not split", } - chunker = ldp.nn.TensorChunker(num_chunks=num_chunks, dummy_value=dummy_value) + chunker = ldp.nn.TensorChunker(num_chunks=num_chunks) split_args, split_kwargs, dummy_chunk_flags = chunker.chunkify( sample_tensor, **sample_kwargs ) @@ -89,13 +86,13 @@ def test_chunkify_with_args_and_kwargs(self): assert torch.equal(split_args[0][0], sample_tensor[:1]) assert torch.equal(split_args[1][0], sample_tensor[1:2]) assert torch.equal( - split_args[2][0], torch.full_like(sample_tensor[:1], dummy_value) + split_args[2][0], sample_tensor[:1] ) assert torch.equal(split_kwargs[0]["key1"], sample_tensor_kwarg[:1]) assert torch.equal(split_kwargs[1]["key1"], sample_tensor_kwarg[1:2]) assert torch.equal( split_kwargs[2]["key1"], - torch.full_like(sample_tensor_kwarg[:1], dummy_value), + sample_tensor_kwarg[:1] ) assert all(split_kwargs[i]["key2"] == "Not split" for i in range(num_chunks)) From 6ed30bd6a23cfecd60ffc07ad1cf13e11736d588 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 4 Mar 2025 13:39:32 -0600 Subject: [PATCH 10/34] nits --- src/ldp/alg/rollout.py | 94 ++++++++++------------ src/ldp/nn/agent/simple_local_agent.py | 5 +- src/ldp/nn/handlers/transformer_handler.py | 5 +- tests/test_nn_models.py | 17 +--- 4 files changed, 51 insertions(+), 70 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 210437de..17c96a10 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -5,7 +5,7 @@ from collections import Counter from collections.abc import Callable, Iterator, Sequence from contextlib import contextmanager, nullcontext -from typing import Any, TypeVar, overload +from typing import Any, TypeVar from aviary.core import Environment, Message from tqdm import tqdm @@ -73,64 +73,52 @@ def __init__( self.traj_buffer: dict[str, Trajectory] = {} self.callbacks = callbacks or [] - @overload - async def sample_trajectories( # noqa: D418 + async def sample_trajectories( self, - environment_factory: Callable[[], TEnv], + environment_factory: Callable[[], TEnv] | None = None, + environments: Sequence[TEnv] | None = None, batch_size: int = 1, max_steps: int | None = None, *, log_exceptions_immediately: bool = False, - ) -> list[tuple[Trajectory, TEnv]]: - """Run rollouts in parallel, using a factory to construct environments. + ) -> list[tuple[Trajectory, Environment]] | list[Trajectory]: + """Sample trajectories from environments, either via factory or pre-created. - We will construct `batch_size` environments and run rollouts on each of them. - If `max_steps` is set, rollouts will be truncated at this value. If a rollout - has fewer than `max_steps`, then a new environment will be constructed and another - rollout will be started until `max_steps` is reached. + There are two main ways to use this method: - Args: - environment_factory: A no-argument callable that returns - an environment instance - batch_size (int, optional): Defaults to 1. - max_steps (int | None, optional): Max steps per rollout. Defaults to None (see above). - log_exceptions_immediately (bool, optional): Whether to log exceptions as they occur - or only at the end of the rollouts. Defaults to False. + 1. Using an environment factory: + Run rollouts in parallel, using a factory to construct environments. + We will construct `batch_size` environments and run rollouts on each of them. + If `max_steps` is set, rollouts will be truncated at this value. If a rollout + has fewer than `max_steps`, then a new environment will be constructed and another + rollout will be started until `max_steps` is reached. - Returns: - list[tuple[Trajectory, Environment]]: A list of (trajectory, environment) tuples: one per rollout. - """ + In this case, returns a list of (trajectory, environment) tuples. - @overload - async def sample_trajectories( # noqa: D418 - self, - environments: Sequence[Environment], - max_steps: int | None = None, - *, - log_exceptions_immediately: bool = False, - ) -> list[Trajectory]: - """Run rollouts in parallel on a list of provided environments. + 2. Using a sequence of environments: + Run rollouts in parallel on a list of provided environments. + In this case, returns a list of trajectories. Args: - environments: A list of environments to run rollouts on. + environment_factory: A no-argument callable that returns an environment instance + environments: A list of environments to run rollouts on + batch_size: Number of parallel environments to run when using environment_factory. Defaults to 1. max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run until environment returns done. - log_exceptions_immediately (bool, optional): Whether to log exceptions as they occur + log_exceptions_immediately: Whether to log exceptions as they occur or only at the end of the rollouts. Defaults to False. - """ - async def sample_trajectories( - self, - environment_factory: Callable[[], Environment] | None = None, - environments: Sequence[Environment] | None = None, - batch_size: int = 1, - max_steps: int | None = None, - *, - log_exceptions_immediately: bool = False, - ) -> list[tuple[Trajectory, Environment]] | list[Trajectory]: - """Sample trajectories from environments, either via factory or pre-created.""" + Returns: + Either list[tuple[Trajectory, Environment]] or list[Trajectory] depending on whether + environment_factory or environments is provided. + + Raises: + TypeError: If neither environment_factory nor environments is provided. + """ if environment_factory is not None: - assert environments is None, "Cannot use environment_factory with environments" + assert environments is None, ( + "Cannot use environment_factory with environments" + ) return await self._sample_trajectories_from_env_factory( environment_factory, batch_size, @@ -139,9 +127,11 @@ async def sample_trajectories( ) if environments is not None: - assert environment_factory is None, "Cannot use environments with environment_factory" + assert environment_factory is None, ( + "Cannot use environments with environment_factory" + ) return await self._sample_trajectories_from_envs( - environments, + environments, max_steps, log_exceptions_immediately=log_exceptions_immediately, ) @@ -225,10 +215,10 @@ async def _sample_trajectories_from_envs( tasks = [ asyncio.create_task( self._rollout( - traj_id, - env, + traj_id, + env, max_steps=max_steps, - log_exceptions_immediately=log_exceptions_immediately + log_exceptions_immediately=log_exceptions_immediately, ) ) for traj_id, env in zip(traj_ids, environments, strict=True) @@ -259,8 +249,7 @@ async def _sample_trajectories_from_envs( if exception_counter: summary = ["Caught exceptions:", "Count Exception"] summary.extend( - f"{count:<6d} {exc:<50s}" - for exc, count in exception_counter.items() + f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() ) logger.info("\n".join(summary)) @@ -271,7 +260,8 @@ async def _rollout( traj_id: str, env: Environment, max_steps: int | None, - max_tokens: int | None = None, # <-- new argument + *, + log_exceptions_immediately: bool = False, ) -> Trajectory: trajectory = Trajectory(traj_id=traj_id) @@ -326,7 +316,7 @@ async def store_step(step: Transition): # We save the last transition here for debugging, etc. if log_exceptions_immediately: logger.exception(f"Exception in rollout {traj_id}: {e.original_exc}") - + await store_step( Transition( timestep=len(trajectory.steps), diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 222ee7ce..00957d0d 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist from aviary.core import Message, Tool, ToolRequestMessage -from litellm import token_counter +from litellm.utils import token_counter from pydantic import Field, field_validator from ldp.agent import Agent, SimpleAgentState @@ -131,10 +131,11 @@ def _validate_token_count(self, messages: list[Message], tools: list[Tool]): return messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages) tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer(tools) + total_tokens = token_counter( model=self.llm_model.model, messages=messages_for_tokenizer, - tools=tools_for_tokenizer, + tools=tools_for_tokenizer, # type: ignore[arg-type] ) if total_tokens > self.llm_model.max_traj_token_count: logger.error( diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 4535c9fd..537a9129 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -17,7 +17,7 @@ import torch.distributed as dist import tree from dask import config -from dask.distributed import Client, as_completed, wait +from dask.distributed import Client, as_completed from pydantic import BaseModel, ConfigDict, Field, field_validator from torch import nn from torch.cuda import nccl @@ -192,14 +192,13 @@ async def __call__( # type: ignore[override] @staticmethod def model_generate(model: PreTrainedModel, *args, **kwargs): """A method that can be used as module_call_fn to sample from an LLM.""" - if dist.get_world_size() > 1: + if int(os.environ.get("WORLD_SIZE", "1")) > 1: synced_gpus = kwargs.pop("synced_gpus", None) if synced_gpus is None: logger.debug("synced_gpus not defined, defaulting to True.") kwargs["synced_gpus"] = True elif not synced_gpus: raise ValueError("synced_gpus must be True when using FSDP.") - # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. diff --git a/tests/test_nn_models.py b/tests/test_nn_models.py index e872d9f1..55917ce3 100644 --- a/tests/test_nn_models.py +++ b/tests/test_nn_models.py @@ -39,12 +39,8 @@ def test_chunkify_add_dummy_chunks(self): assert torch.equal(split_args[0][0], sample_tensor[:1]) assert torch.equal(split_args[1][0], sample_tensor[1:2]) assert torch.equal(split_args[2][0], sample_tensor[2:3]) - assert torch.equal( - split_args[3][0], sample_tensor[:1] - ) - assert torch.equal( - split_args[4][0], sample_tensor[:1] - ) + assert torch.equal(split_args[3][0], sample_tensor[:1]) + assert torch.equal(split_args[4][0], sample_tensor[:1]) def test_chunkify_no_dummy_chunks(self): batch_size = 9 @@ -85,15 +81,10 @@ def test_chunkify_with_args_and_kwargs(self): assert dummy_chunk_flags == [False, False, True] assert torch.equal(split_args[0][0], sample_tensor[:1]) assert torch.equal(split_args[1][0], sample_tensor[1:2]) - assert torch.equal( - split_args[2][0], sample_tensor[:1] - ) + assert torch.equal(split_args[2][0], sample_tensor[:1]) assert torch.equal(split_kwargs[0]["key1"], sample_tensor_kwarg[:1]) assert torch.equal(split_kwargs[1]["key1"], sample_tensor_kwarg[1:2]) - assert torch.equal( - split_kwargs[2]["key1"], - sample_tensor_kwarg[:1] - ) + assert torch.equal(split_kwargs[2]["key1"], sample_tensor_kwarg[:1]) assert all(split_kwargs[i]["key2"] == "Not split" for i in range(num_chunks)) def test_dechunkify(self): From d245e3de0789202f818e802fa7f24f20c68971f2 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 4 Mar 2025 13:53:35 -0600 Subject: [PATCH 11/34] nits --- src/ldp/alg/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 17c96a10..9f34cf57 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -246,7 +246,7 @@ async def _sample_trajectories_from_envs( pbar.set_postfix({"num_exceptions": num_exceptions}) # Final summary of exceptions (if any) - if exception_counter: + if exception_counter and not log_exceptions_immediately: summary = ["Caught exceptions:", "Count Exception"] summary.extend( f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() From 51210dfcefa242350cfccc81609329f2425a742c Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 4 Mar 2025 14:16:33 -0600 Subject: [PATCH 12/34] nits --- src/ldp/alg/rollout.py | 88 ++++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 29 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 9f34cf57..0ab2b7b6 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -80,7 +80,7 @@ async def sample_trajectories( batch_size: int = 1, max_steps: int | None = None, *, - log_exceptions_immediately: bool = False, + log_exceptions_immediately: bool = True, ) -> list[tuple[Trajectory, Environment]] | list[Trajectory]: """Sample trajectories from environments, either via factory or pre-created. @@ -106,7 +106,7 @@ async def sample_trajectories( max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run until environment returns done. log_exceptions_immediately: Whether to log exceptions as they occur - or only at the end of the rollouts. Defaults to False. + or only at the end of the rollouts. Returns: Either list[tuple[Trajectory, Environment]] or list[Trajectory] depending on whether @@ -147,14 +147,17 @@ async def _sample_trajectories_from_env_factory( batch_size: int = 1, max_steps: int | None = None, *, - log_exceptions_immediately: bool = False, + log_exceptions_immediately: bool = True, ) -> list[tuple[Trajectory, Environment]]: self.traj_buffer.clear() + exception_counter: Counter = Counter() async def rollout_with_args(idx: int, **rollout_kwargs): return idx, await self._rollout(**rollout_kwargs), rollout_kwargs accumulated_steps = [0] * batch_size + total_trajectories = 0 # Counter for completed trajectories + # submit initial batch of tasks tasks = [ asyncio.create_task( @@ -170,32 +173,59 @@ async def rollout_with_args(idx: int, **rollout_kwargs): ] results = [] - while tasks: - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - new_tasks = [] - for task in done: - idx, traj, kwargs = await task - results.append((traj, kwargs["env"])) - accumulated_steps[idx] += len(traj.steps) - if ( - max_steps is not None - and (remaining_steps := max_steps - accumulated_steps[idx]) > 0 - ): - # submit another task if we haven't reached max_steps - new_task = asyncio.create_task( - rollout_with_args( - idx, - traj_id=uuid.uuid4().hex, - env=environment_factory(), - max_steps=remaining_steps, - log_exceptions_immediately=log_exceptions_immediately, + with tqdm( + desc="Rollouts", + unit="rollout", + ncols=0, + ) as pbar: + while tasks: + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + new_tasks = [] + for task in done: + idx, traj, kwargs = await task + results.append((traj, kwargs["env"])) + total_trajectories += 1 + pbar.update(1) + + steps_in_traj = len(traj.steps) + accumulated_steps[idx] += steps_in_traj + + # Check for exceptions in this trajectory + if traj.steps and traj.steps[-1].metadata.get("exception"): + exc_str: str = str(traj.steps[-1].metadata["exception"])[ + :500 + ].replace('"', "'") + exception_counter[exc_str] += 1 + num_exceptions = sum(exception_counter.values()) + pbar.set_postfix({"num_exceptions": num_exceptions}) + + if ( + max_steps is not None + and (remaining_steps := max_steps - accumulated_steps[idx]) > 0 + ): + # submit another task if we haven't reached max_steps + new_task = asyncio.create_task( + rollout_with_args( + idx, + traj_id=uuid.uuid4().hex, + env=environment_factory(), + max_steps=remaining_steps, + log_exceptions_immediately=log_exceptions_immediately, + ) ) - ) - new_tasks.append(new_task) + new_tasks.append(new_task) - tasks = list(pending) + new_tasks + tasks = list(pending) + new_tasks + + # Final summary of exceptions (if any) + if exception_counter and not log_exceptions_immediately: + summary = ["Caught exceptions:", "Count Exception"] + summary.extend( + f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() + ) + logger.info("\n".join(summary)) return results @@ -204,7 +234,7 @@ async def _sample_trajectories_from_envs( environments: Sequence[Environment], max_steps: int | None = None, *, - log_exceptions_immediately: bool = False, + log_exceptions_immediately: bool = True, ) -> list[Trajectory]: self.traj_buffer.clear() exception_counter: Counter = Counter() @@ -261,7 +291,7 @@ async def _rollout( env: Environment, max_steps: int | None, *, - log_exceptions_immediately: bool = False, + log_exceptions_immediately: bool = True, ) -> Trajectory: trajectory = Trajectory(traj_id=traj_id) From 3d7c20e7aef8ea320ad1b5eaef83aadd837a1ba2 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 12 Mar 2025 03:48:39 -0500 Subject: [PATCH 13/34] Refactor Dask handling in transformer handler for improved exception management and memory efficiency --- src/ldp/alg/rollout.py | 2 +- src/ldp/nn/handlers/transformer_handler.py | 128 ++++++++++++++------- 2 files changed, 90 insertions(+), 40 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 0ab2b7b6..da40f669 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -8,7 +8,7 @@ from typing import Any, TypeVar from aviary.core import Environment, Message -from tqdm import tqdm +from tqdm.asyncio import tqdm from ldp.agent import Agent from ldp.data_structures import Trajectory, Transition diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 537a9129..6ee5ae54 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import atexit import logging import os @@ -10,14 +11,15 @@ from enum import StrEnum, auto from functools import cache, partial, wraps from pathlib import Path -from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, cast +from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never import accelerate import torch import torch.distributed as dist import tree from dask import config -from dask.distributed import Client, as_completed +from dask.distributed import Actor, ActorFuture, Client +from distributed.utils import sync from pydantic import BaseModel, ConfigDict, Field, field_validator from torch import nn from torch.cuda import nccl @@ -199,6 +201,7 @@ def model_generate(model: PreTrainedModel, *args, **kwargs): kwargs["synced_gpus"] = True elif not synced_gpus: raise ValueError("synced_gpus must be True when using FSDP.") + raise torch.OutOfMemoryError("yoyoyoyoyoyoyo test test test TODO") # TODO remove # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. @@ -425,22 +428,29 @@ def _exec_func( args = tree.map_structure(to_device, args) kwargs = tree.map_structure(to_device, kwargs) - with torch.autocast( - device_type=self.module.device.type, dtype=self.module.dtype - ): - res = ( - getattr(self, func)(*args, **kwargs) - if isinstance(func, str) - else func(self, *args, **kwargs) - ) + try: + with torch.autocast( + device_type=self.module.device.type, dtype=self.module.dtype + ): + res = ( + getattr(self, func)(*args, **kwargs) + if isinstance(func, str) + else func(self, *args, **kwargs) + ) - # Needed to prevent GPU memory leak to the main process scheduling the workers - if isinstance(res, GenerateDecoderOnlyOutput): - res.past_key_values = None - res["past_key_values"] = None + # Needed to prevent GPU memory leak to the main process scheduling the workers + if isinstance(res, GenerateDecoderOnlyOutput): + res.past_key_values = None + res["past_key_values"] = None - to_cpu = partial(_move_tensor, device=torch.device("cpu")) - return tree.map_structure(to_cpu, res) + to_cpu = partial(_move_tensor, device=torch.device("cpu")) + return tree.map_structure(to_cpu, res) + except Exception as e: + # Re-raise the exception with traceback preserved. For some exceptions, Dask + # modifies or loses the original traceback when crossing process boundaries. + # RuntimeError preserves the traceback when using with_traceback() of original + # exception. + raise RuntimeError(str(e)).with_traceback(e.__traceback__) # noqa: B904 def __del__(self) -> None: dist.destroy_process_group() @@ -582,7 +592,7 @@ def get_cuda_visible_devices() -> int | None: futures.append(future_op) worker_ids.append(worker_id) - self.handlers = self.client_gather(futures) + self.actors: list[Actor] = self._client_gather(futures) self.worker_ids = worker_ids async def __call__( @@ -633,28 +643,24 @@ def _submit_and_gather( """ if split_data: chunker = TensorChunker( - num_chunks=len(self.handlers), + num_chunks=len(self.actors), ) split_args, split_kwargs, dummy_flags = chunker.chunkify(*args, **kwargs) else: - split_args = [args] * len(self.handlers) - split_kwargs = [kwargs] * len(self.handlers) + split_args = [args] * len(self.actors) + split_kwargs = [kwargs] * len(self.actors) futures = [ - self.client.submit( - handler._exec_func, + handler._exec_func( func, *args_i, - workers=[worker_id], - actor=True, **kwargs_i, ) for handler, worker_id, args_i, kwargs_i in zip( - self.handlers, self.worker_ids, split_args, split_kwargs, strict=True + self.actors, self.worker_ids, split_args, split_kwargs, strict=True ) ] - results = self.client_gather(futures) - results = cast("list[TReturn]", [res.result().result() for res in results]) + results: list[TReturn] = self._client_gather(futures) if split_data: return chunker.dechunkify(results, dummy_flags) @@ -767,29 +773,73 @@ def teardown(self) -> None: if self._initialized: self.client.shutdown() self.cluster.close() + del self.client + del self.cluster self._initialized = False def __del__(self) -> None: self.teardown() - def client_gather(self, futures): + @staticmethod + def _wrap_dask_future(dask_future: ActorFuture): + """Converts a Dask ActorFuture into an awaitable asyncio.Future.""" + loop = asyncio.get_running_loop() + return asyncio.ensure_future(loop.run_in_executor(None, dask_future.result)) + + def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: """Gather results from futures, propagating exceptions as they arrive. Unlike client.gather() which waits for all futures to complete before raising any exceptions, this method processes futures as they complete and raises exceptions immediately. This is crucial when using FSDP where workers may - be stuck waiting for each other where one worker crashes, causing long hangs. + be stuck waiting for each other when one worker crashes, causing long hangs. + + Note: Dask Actors currently have an issue where they're not working properly with + dask.gather() and can cause blocking issues or hide worker errors. This implementation + works around those limitations. """ - # Initialize a list to hold results - results = [None] * len(futures) - for completed_future, result in as_completed( - futures, with_results=True, raise_errors=True - ): - # Find the index of the completed future - index = futures.index(completed_future) - # Store the result directly from as_completed - results[index] = result - return results + + async def _gather_with_exception_handling(futures): + wrapped_futures = [self._wrap_dask_future(f) for f in futures] + + try: + # Use asyncio.wait with FIRST_EXCEPTION instead of gather + done, pending = await asyncio.wait( + wrapped_futures, timeout=120, return_when=asyncio.FIRST_EXCEPTION + ) + + exceptions = [] + for future in done: + exc = future.exception() + if exc: + exceptions.append(exc) + if exceptions: + if len(exceptions) == 1: + raise exceptions[0] + raise ExceptionGroup("Multiple actor exceptions", exceptions) + + if pending: + pending_indices = sorted([ + wrapped_futures.index(p) for p in pending + ]) + raise TimeoutError( + f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} " + f"still pending. Pending task indices: {pending_indices}" + ) + + return await asyncio.gather(*wrapped_futures) + except Exception as e: + logger.exception("Error in dask workers") + for f in wrapped_futures: + if not f.done(): + f.cancel() + self.teardown() + # sys.exit(1) would wait for dask to finish, which can cause hanging + # when workers are in a deadlock. Use os._exit to force immediate termination + os._exit(1) + + # Use distributed.utils.sync to run the async function in the current thread + return sync(self.client.loop, _gather_with_exception_handling, futures) # type: ignore[arg-type] # Helpers From 15b936d1f12a211d67b98451c62e03b2bf3a9e81 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 12 Mar 2025 05:49:12 -0500 Subject: [PATCH 14/34] Remove test OutOfMemoryError raise in AsyncTransformerInterface --- src/ldp/nn/handlers/transformer_handler.py | 46 ++++++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 6ee5ae54..ddf159f8 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -201,7 +201,6 @@ def model_generate(model: PreTrainedModel, *args, **kwargs): kwargs["synced_gpus"] = True elif not synced_gpus: raise ValueError("synced_gpus must be True when using FSDP.") - raise torch.OutOfMemoryError("yoyoyoyoyoyoyo test test test TODO") # TODO remove # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. @@ -240,14 +239,22 @@ def __init__(self, config: TransformerHandlerConfig): assert_never(config.lm_type) super().__init__(model) self.tokenizer = tokenizer + logger.info( + f"Initialized tokenizer: {type(tokenizer).__name__}, vocab size: {len(tokenizer)}" + ) + maybe_set_tokenizer_chat_template( self.tokenizer, self.config.lm_config.chat_template ) + logger.info(f"Chat template: {self.config.lm_config.chat_template}") self._setup_accelerator() + logger.info(f"Accelerator set up with device: {self.accelerator.device}") if config.checkpoint is not None: + logger.info(f"Loading checkpoint from: {config.checkpoint}") self.load_checkpoint(config.checkpoint) + logger.info("Initialization complete") def _setup_accelerator(self): self.accelerator = accelerate.Accelerator( @@ -391,6 +398,13 @@ def _setup_accelerator(self): buffer_dtype=torch.bfloat16, ) + logger.info(f"Setting up accelerator with bf16={bf16}") + logger.info(f"Worker config: offload_cpu={self.worker_config.offload_cpu}, " + f"activation_checkpointing={self.worker_config.activation_checkpointing}, " + f"cpu_ram_efficient_loading={self.worker_config.cpu_ram_efficient_loading}, " + f"state_dict_type={self.worker_config.state_dict_type}, " + f"backward_prefetch={self.worker_config.backward_prefetch}") + self.accelerator = accelerate.Accelerator( # See note in TransformerHandler._setup_accelerator() about this # mixed_precision=("bf16" if bf16 else "no"), @@ -406,11 +420,16 @@ def _setup_accelerator(self): backward_prefetch=self.worker_config.backward_prefetch, ), ) + logger.info(f"Accelerator setup complete on rank {self.worker_config.rank}") if self.config.lm_config.device == "meta": + logger.info(f"Preparing model for FSDP with meta device on rank {self.worker_config.rank}") self.module = prepare_model_for_fsdp_with_meta_device(self.module) + logger.info(f"Meta device preparation complete on rank {self.worker_config.rank}") + logger.info(f"Preparing model with accelerator on rank {self.worker_config.rank}") self.module = self.accelerator.prepare(self.module) + logger.info(f"Model preparation complete on rank {self.worker_config.rank}, model device: {self.module.device}, dtype: {self.module.dtype}") def set_seed(self, seed: int) -> None: """Set the seed for the current worker.""" @@ -770,6 +789,7 @@ def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: self._submit_and_gather("save_checkpoint", ckpt, **kwargs) def teardown(self) -> None: + logger.info(f"Shutting down Dask cluster, is_initialized: {self._initialized}") if self._initialized: self.client.shutdown() self.cluster.close() @@ -800,14 +820,14 @@ def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: """ async def _gather_with_exception_handling(futures): - wrapped_futures = [self._wrap_dask_future(f) for f in futures] - try: + wrapped_futures = [self._wrap_dask_future(f) for f in futures] + # Use asyncio.wait with FIRST_EXCEPTION instead of gather done, pending = await asyncio.wait( - wrapped_futures, timeout=120, return_when=asyncio.FIRST_EXCEPTION + wrapped_futures, timeout=1200, return_when=asyncio.FIRST_EXCEPTION ) - + exceptions = [] for future in done: exc = future.exception() @@ -819,9 +839,7 @@ async def _gather_with_exception_handling(futures): raise ExceptionGroup("Multiple actor exceptions", exceptions) if pending: - pending_indices = sorted([ - wrapped_futures.index(p) for p in pending - ]) + pending_indices = sorted([wrapped_futures.index(p) for p in pending]) raise TimeoutError( f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} " f"still pending. Pending task indices: {pending_indices}" @@ -829,19 +847,21 @@ async def _gather_with_exception_handling(futures): return await asyncio.gather(*wrapped_futures) except Exception as e: - logger.exception("Error in dask workers") - for f in wrapped_futures: - if not f.done(): - f.cancel() + logger.exception("Error in dask workers: %s") + for future in wrapped_futures: + future.cancel() self.teardown() # sys.exit(1) would wait for dask to finish, which can cause hanging # when workers are in a deadlock. Use os._exit to force immediate termination - os._exit(1) + # TODO: this is more of a hack, we should propagate special exception that is + # not caught by the rollout manager. + os._exit(1) # Use distributed.utils.sync to run the async function in the current thread return sync(self.client.loop, _gather_with_exception_handling, futures) # type: ignore[arg-type] + # Helpers From ad375011b868d17cddd77903e448c5d72cadc5c9 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 12 Mar 2025 07:14:03 -0500 Subject: [PATCH 15/34] Refactor exception handling in ParallelAsyncTransformer for improved clarity and reliability --- src/ldp/nn/handlers/transformer_handler.py | 98 +++++++++------------- 1 file changed, 40 insertions(+), 58 deletions(-) diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index ddf159f8..82e9d0e3 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -239,22 +239,15 @@ def __init__(self, config: TransformerHandlerConfig): assert_never(config.lm_type) super().__init__(model) self.tokenizer = tokenizer - logger.info( - f"Initialized tokenizer: {type(tokenizer).__name__}, vocab size: {len(tokenizer)}" - ) maybe_set_tokenizer_chat_template( self.tokenizer, self.config.lm_config.chat_template ) - logger.info(f"Chat template: {self.config.lm_config.chat_template}") self._setup_accelerator() - logger.info(f"Accelerator set up with device: {self.accelerator.device}") if config.checkpoint is not None: - logger.info(f"Loading checkpoint from: {config.checkpoint}") self.load_checkpoint(config.checkpoint) - logger.info("Initialization complete") def _setup_accelerator(self): self.accelerator = accelerate.Accelerator( @@ -398,13 +391,6 @@ def _setup_accelerator(self): buffer_dtype=torch.bfloat16, ) - logger.info(f"Setting up accelerator with bf16={bf16}") - logger.info(f"Worker config: offload_cpu={self.worker_config.offload_cpu}, " - f"activation_checkpointing={self.worker_config.activation_checkpointing}, " - f"cpu_ram_efficient_loading={self.worker_config.cpu_ram_efficient_loading}, " - f"state_dict_type={self.worker_config.state_dict_type}, " - f"backward_prefetch={self.worker_config.backward_prefetch}") - self.accelerator = accelerate.Accelerator( # See note in TransformerHandler._setup_accelerator() about this # mixed_precision=("bf16" if bf16 else "no"), @@ -420,16 +406,11 @@ def _setup_accelerator(self): backward_prefetch=self.worker_config.backward_prefetch, ), ) - logger.info(f"Accelerator setup complete on rank {self.worker_config.rank}") if self.config.lm_config.device == "meta": - logger.info(f"Preparing model for FSDP with meta device on rank {self.worker_config.rank}") self.module = prepare_model_for_fsdp_with_meta_device(self.module) - logger.info(f"Meta device preparation complete on rank {self.worker_config.rank}") - logger.info(f"Preparing model with accelerator on rank {self.worker_config.rank}") self.module = self.accelerator.prepare(self.module) - logger.info(f"Model preparation complete on rank {self.worker_config.rank}, model device: {self.module.device}, dtype: {self.module.dtype}") def set_seed(self, seed: int) -> None: """Set the seed for the current worker.""" @@ -789,7 +770,6 @@ def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: self._submit_and_gather("save_checkpoint", ckpt, **kwargs) def teardown(self) -> None: - logger.info(f"Shutting down Dask cluster, is_initialized: {self._initialized}") if self._initialized: self.client.shutdown() self.cluster.close() @@ -806,7 +786,26 @@ def _wrap_dask_future(dask_future: ActorFuture): loop = asyncio.get_running_loop() return asyncio.ensure_future(loop.run_in_executor(None, dask_future.result)) - def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: + @staticmethod + def _raise_exceptions(done, pending, wrapped_futures): + exceptions = [] + for future in done: + exc = future.exception() + if exc: + exceptions.append(exc) + if exceptions: + if len(exceptions) == 1: + raise exceptions[0] + raise ExceptionGroup("Multiple actor exceptions", exceptions) + + if pending: + pending_indices = sorted([wrapped_futures.index(p) for p in pending]) + raise TimeoutError( + f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} " + f"still pending. Pending task indices: {pending_indices}" + ) + + async def _client_gather_async(self, futures): """Gather results from futures, propagating exceptions as they arrive. Unlike client.gather() which waits for all futures to complete before raising @@ -818,48 +817,31 @@ def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: dask.gather() and can cause blocking issues or hide worker errors. This implementation works around those limitations. """ + try: + wrapped_futures = [self._wrap_dask_future(f) for f in futures] - async def _gather_with_exception_handling(futures): - try: - wrapped_futures = [self._wrap_dask_future(f) for f in futures] + # Use asyncio.wait with FIRST_EXCEPTION instead of gather + done, pending = await asyncio.wait( + wrapped_futures, timeout=1200, return_when=asyncio.FIRST_EXCEPTION + ) - # Use asyncio.wait with FIRST_EXCEPTION instead of gather - done, pending = await asyncio.wait( - wrapped_futures, timeout=1200, return_when=asyncio.FIRST_EXCEPTION - ) + self._raise_exceptions(done, pending, wrapped_futures) - exceptions = [] - for future in done: - exc = future.exception() - if exc: - exceptions.append(exc) - if exceptions: - if len(exceptions) == 1: - raise exceptions[0] - raise ExceptionGroup("Multiple actor exceptions", exceptions) - - if pending: - pending_indices = sorted([wrapped_futures.index(p) for p in pending]) - raise TimeoutError( - f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} " - f"still pending. Pending task indices: {pending_indices}" - ) - - return await asyncio.gather(*wrapped_futures) - except Exception as e: - logger.exception("Error in dask workers: %s") - for future in wrapped_futures: - future.cancel() - self.teardown() - # sys.exit(1) would wait for dask to finish, which can cause hanging - # when workers are in a deadlock. Use os._exit to force immediate termination - # TODO: this is more of a hack, we should propagate special exception that is - # not caught by the rollout manager. - os._exit(1) + return await asyncio.gather(*wrapped_futures) + except Exception: + logger.exception("Error in dask workers: %s") + for future in wrapped_futures: + future.cancel() + self.teardown() + # sys.exit(1) would wait for dask to finish, which can cause hanging + # when workers are in a deadlock. Use os._exit to force immediate termination + # TODO: this is more of a hack, we should propagate special exception that is + # not caught by the rollout manager. + os._exit(1) + def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: # Use distributed.utils.sync to run the async function in the current thread - return sync(self.client.loop, _gather_with_exception_handling, futures) # type: ignore[arg-type] - + return sync(self.client.loop, self._client_gather_async, futures) # type: ignore[arg-type] # Helpers From 62f42d108afa2aefce2035496f3413717a3fb7ca Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 12 Mar 2025 07:47:42 -0500 Subject: [PATCH 16/34] nits --- src/ldp/alg/rollout.py | 3 +-- src/ldp/nn/agent/simple_local_agent.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index da40f669..10dcf820 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -46,7 +46,6 @@ def reraise_exc_as(reraise: type[CaughtError], enabled: bool) -> Iterator[None]: yield except Exception as e: if enabled: - # Minimal logging instead of spamming. Detailed error stored in the trajectory's metadata. logger.debug(f"Reraising {reraise.exc_type} exception.") raise reraise(e) from None raise @@ -81,7 +80,7 @@ async def sample_trajectories( max_steps: int | None = None, *, log_exceptions_immediately: bool = True, - ) -> list[tuple[Trajectory, Environment]] | list[Trajectory]: + ): """Sample trajectories from environments, either via factory or pre-created. There are two main ways to use this method: diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 00957d0d..d7fbd86d 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -127,6 +127,7 @@ async def get_asv( return cast("OpResult[ToolRequestMessage]", result), next_state, 0.0 def _validate_token_count(self, messages: list[Message], tools: list[Tool]): + """Asserts token count for the trajectory is within the limit.""" if self.llm_model.max_traj_token_count is None: return messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages) From 1b486cfd29f73ed6c43e50bd93ddbb32814050af Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 19 Mar 2025 02:54:09 -0500 Subject: [PATCH 17/34] nits code review --- src/ldp/alg/rollout.py | 88 +++++++++++----------- src/ldp/graph/async_torch.py | 14 +++- src/ldp/nn/handlers/transformer_handler.py | 2 +- 3 files changed, 59 insertions(+), 45 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 10dcf820..74519936 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -5,13 +5,14 @@ from collections import Counter from collections.abc import Callable, Iterator, Sequence from contextlib import contextmanager, nullcontext -from typing import Any, TypeVar +from typing import Any, TypeVar, overload from aviary.core import Environment, Message from tqdm.asyncio import tqdm from ldp.agent import Agent from ldp.data_structures import Trajectory, Transition +from ldp.utils import format_error_details from .callbacks import Callback @@ -72,67 +73,65 @@ def __init__( self.traj_buffer: dict[str, Trajectory] = {} self.callbacks = callbacks or [] - async def sample_trajectories( + @overload + async def sample_trajectories( # noqa: D418 self, - environment_factory: Callable[[], TEnv] | None = None, - environments: Sequence[TEnv] | None = None, + environment_factory: Callable[[], TEnv], batch_size: int = 1, max_steps: int | None = None, - *, - log_exceptions_immediately: bool = True, - ): - """Sample trajectories from environments, either via factory or pre-created. + ) -> list[tuple[Trajectory, TEnv]]: + """Run rollouts in parallel, using a factory to construct environments. - There are two main ways to use this method: + We will construct `batch_size` environments and run rollouts on each of them. + If `max_steps` is set, rollouts will be truncated at this value. If a rollout + has fewer than `max_steps`, then a new environment will be constructed and another + rollout will be started until `max_steps` is reached. - 1. Using an environment factory: - Run rollouts in parallel, using a factory to construct environments. - We will construct `batch_size` environments and run rollouts on each of them. - If `max_steps` is set, rollouts will be truncated at this value. If a rollout - has fewer than `max_steps`, then a new environment will be constructed and another - rollout will be started until `max_steps` is reached. + Args: + environment_factory: A no-argument callable that returns + an environment instance + batch_size (int, optional): Defaults to 1. + max_steps (int | None, optional): Max steps per rollout. Defaults to None (see above). - In this case, returns a list of (trajectory, environment) tuples. + Returns: + list[tuple[Trajectory, Environment]]: A list of (trajectory, environment) tuples: one per rollout. + """ - 2. Using a sequence of environments: - Run rollouts in parallel on a list of provided environments. - In this case, returns a list of trajectories. + @overload + async def sample_trajectories( # noqa: D418 + self, + environments: Sequence[Environment], + max_steps: int | None = None, + ) -> list[Trajectory]: + """Run rollouts in parallel on a list of provided environments. Args: - environment_factory: A no-argument callable that returns an environment instance - environments: A list of environments to run rollouts on - batch_size: Number of parallel environments to run when using environment_factory. Defaults to 1. + environments: A list of environments to run rollouts on. max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run until environment returns done. - log_exceptions_immediately: Whether to log exceptions as they occur - or only at the end of the rollouts. - - Returns: - Either list[tuple[Trajectory, Environment]] or list[Trajectory] depending on whether - environment_factory or environments is provided. - - Raises: - TypeError: If neither environment_factory nor environments is provided. """ - if environment_factory is not None: - assert environments is None, ( + + async def sample_trajectories(self, **kwargs): + if "environment_factory" in kwargs: + assert "environments" not in kwargs, ( "Cannot use environment_factory with environments" ) + return await self._sample_trajectories_from_env_factory( - environment_factory, - batch_size, - max_steps, - log_exceptions_immediately=log_exceptions_immediately, + kwargs["environment_factory"], + kwargs.get("batch_size", 1), + kwargs.get("max_steps"), + log_exceptions_immediately=kwargs.get("log_exceptions_immediately", True) ) - if environments is not None: - assert environment_factory is None, ( + if "environments" in kwargs: + assert "environment_factory" not in kwargs, ( "Cannot use environments with environment_factory" ) return await self._sample_trajectories_from_envs( - environments, - max_steps, - log_exceptions_immediately=log_exceptions_immediately, + kwargs["environments"], + kwargs.get("max_steps"), + log_exceptions_immediately=kwargs.get("log_exceptions_immediately", True), ) raise TypeError( @@ -176,6 +175,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): desc="Rollouts", unit="rollout", ncols=0, + disable=log_exceptions_immediately, ) as pbar: while tasks: done, pending = await asyncio.wait( @@ -258,6 +258,7 @@ async def _sample_trajectories_from_envs( desc="Rollouts", unit="rollout", ncols=0, + disable=log_exceptions_immediately, ) as pbar: for task in asyncio.as_completed(tasks): trajectory = await task @@ -344,7 +345,8 @@ async def store_step(step: Transition): # NOTE: This trajectory should not be used for regular training. # We save the last transition here for debugging, etc. if log_exceptions_immediately: - logger.exception(f"Exception in rollout {traj_id}: {e.original_exc}") + error_details = format_error_details(e.original_exc) + logger.exception(f"Exception in rollout {traj_id}:\n{error_details}") await store_step( Transition( diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index d0aa8274..2eb19e85 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -1,6 +1,7 @@ __all__ = ["AsyncTorchModule", "async_protect_torch_call"] import asyncio +import logging import operator import time from abc import ABC, abstractmethod @@ -19,6 +20,9 @@ "Please run `pip install ldp[nn]`." ) from None + +logger = logging.getLogger(__name__) + _TORCH_LOCK = asyncio.Lock() # Supported devices here: https://pytorch.org/docs/stable/amp.html#torch.autocast @@ -90,6 +94,7 @@ def __init__( self._work_buffer: list[tuple[float, UUID, dict[str, Any]]] = [] self._result_buffer: dict[UUID, Any] = {} self._lock = asyncio.Lock() + self._exception_raised: Exception | None = None async def __call__(self, **kwargs): request_id = uuid4() @@ -104,13 +109,20 @@ async def __call__(self, **kwargs): # Only one coroutine allowed in here when: # - modifying the result buffer # - modifying the work buffer + if self._exception_raised is not None: + logger.info("Exception raised in another coroutine") + raise self._exception_raised if request_id in self._result_buffer: # Our request was fulfilled by this or another coroutine! return self._result_buffer.pop(request_id) # Try to run a batch. - await self._maybe_process_batch() + try: + await self._maybe_process_batch() + except Exception as e: + self._exception_raised = e + raise # Sleep, to let another coroutine take over if it needs to await asyncio.sleep(0.0) diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 82e9d0e3..7ca01b74 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -59,7 +59,7 @@ # Gives us more time to debug a downed worker. TODO: see if there are negative consequences # of having this always enabled "distributed.comm.timeouts.connect": "300s", - "distributed.comm.timeouts.tcp": "300s", + "distributed.comm.timeouts.tcp": "1200s", }) TReturn = TypeVar("TReturn") From bddaa92f11cc69cd86ff974e474f73faa6967dde Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 19 Mar 2025 06:32:37 -0500 Subject: [PATCH 18/34] nits --- src/ldp/alg/rollout.py | 8 ++++++-- src/ldp/nn/agent/simple_local_agent.py | 10 +++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 74519936..28c58450 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -121,7 +121,9 @@ async def sample_trajectories(self, **kwargs): kwargs["environment_factory"], kwargs.get("batch_size", 1), kwargs.get("max_steps"), - log_exceptions_immediately=kwargs.get("log_exceptions_immediately", True) + log_exceptions_immediately=kwargs.get( + "log_exceptions_immediately", True + ), ) if "environments" in kwargs: @@ -131,7 +133,9 @@ async def sample_trajectories(self, **kwargs): return await self._sample_trajectories_from_envs( kwargs["environments"], kwargs.get("max_steps"), - log_exceptions_immediately=kwargs.get("log_exceptions_immediately", True), + log_exceptions_immediately=kwargs.get( + "log_exceptions_immediately", True + ), ) raise TypeError( diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index d7fbd86d..82d5f5ec 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -46,7 +46,7 @@ class AgentLMConfig(_LMConfig): ), validate_default=True, ) - max_traj_token_count: int | None = Field( + max_messages_token_count: int | None = Field( default=None, description="If set, raise an error if the total tokens in the trajectory exceed this value.", ) @@ -128,7 +128,7 @@ async def get_asv( def _validate_token_count(self, messages: list[Message], tools: list[Tool]): """Asserts token count for the trajectory is within the limit.""" - if self.llm_model.max_traj_token_count is None: + if self.llm_model.max_messages_token_count is None: return messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages) tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer(tools) @@ -138,12 +138,12 @@ def _validate_token_count(self, messages: list[Message], tools: list[Tool]): messages=messages_for_tokenizer, tools=tools_for_tokenizer, # type: ignore[arg-type] ) - if total_tokens > self.llm_model.max_traj_token_count: + if total_tokens > self.llm_model.max_messages_token_count: logger.error( - f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" + f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}" ) raise ValueError( - f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}" + f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}" ) # TODO: maybe remove these recomputation methods. I added them to debug some things. But idk, From 310131eab6a8bf461fec75f0672cdebc6c78fbe2 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 19 Mar 2025 06:57:37 -0500 Subject: [PATCH 19/34] nits --- src/ldp/alg/rollout.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 28c58450..e6645049 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -109,6 +109,9 @@ async def sample_trajectories( # noqa: D418 environments: A list of environments to run rollouts on. max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run until environment returns done. + log_exceptions_immediately: Whether to log exceptions in the rollout immediately + to the console. Defaults to True. If False, progress bar will show and a summary + will be logged after all rollouts are complete. """ async def sample_trajectories(self, **kwargs): From 4e4090890f70d52bbe2539bf6c5f7cc2d81bdd37 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 19 Mar 2025 06:59:33 -0500 Subject: [PATCH 20/34] nits --- src/ldp/graph/async_torch.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index 2eb19e85..f9f76626 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -109,10 +109,6 @@ async def __call__(self, **kwargs): # Only one coroutine allowed in here when: # - modifying the result buffer # - modifying the work buffer - if self._exception_raised is not None: - logger.info("Exception raised in another coroutine") - raise self._exception_raised - if request_id in self._result_buffer: # Our request was fulfilled by this or another coroutine! return self._result_buffer.pop(request_id) From 5fb76d452ed2b785c7d9772de2f17e2cc7aa2c86 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 19 Mar 2025 08:03:34 -0500 Subject: [PATCH 21/34] nit --- src/ldp/nn/handlers/transformer_handler.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 7ca01b74..42a74637 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -194,14 +194,6 @@ async def __call__( # type: ignore[override] @staticmethod def model_generate(model: PreTrainedModel, *args, **kwargs): """A method that can be used as module_call_fn to sample from an LLM.""" - if int(os.environ.get("WORLD_SIZE", "1")) > 1: - synced_gpus = kwargs.pop("synced_gpus", None) - if synced_gpus is None: - logger.debug("synced_gpus not defined, defaulting to True.") - kwargs["synced_gpus"] = True - elif not synced_gpus: - raise ValueError("synced_gpus must be True when using FSDP.") - # Summoning params per https://github.com/pytorch/pytorch/issues/100069 # If model is not FSDP, this context manager is a no-op. with FullyShardedDataParallel.summon_full_params(model, recurse=False): @@ -239,7 +231,6 @@ def __init__(self, config: TransformerHandlerConfig): assert_never(config.lm_type) super().__init__(model) self.tokenizer = tokenizer - maybe_set_tokenizer_chat_template( self.tokenizer, self.config.lm_config.chat_template ) From e6ee0902ef525a6c405d08caddaddbb9ee4f7f18 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 31 Mar 2025 04:43:06 -0500 Subject: [PATCH 22/34] nits comments fix --- src/ldp/alg/rollout.py | 37 +++++++++++++------------- src/ldp/graph/async_torch.py | 3 --- src/ldp/nn/agent/simple_local_agent.py | 2 -- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index e6645049..105dc173 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -109,9 +109,9 @@ async def sample_trajectories( # noqa: D418 environments: A list of environments to run rollouts on. max_steps: Max steps per rollout. Defaults to None, in which case the rollouts are run until environment returns done. - log_exceptions_immediately: Whether to log exceptions in the rollout immediately - to the console. Defaults to True. If False, progress bar will show and a summary - will be logged after all rollouts are complete. + summarize_exceptions: Whether to collect exceptions and show a summary at the end. + Defaults to True. If False, exceptions will be logged immediately as they occur + during rollout. """ async def sample_trajectories(self, **kwargs): @@ -124,8 +124,8 @@ async def sample_trajectories(self, **kwargs): kwargs["environment_factory"], kwargs.get("batch_size", 1), kwargs.get("max_steps"), - log_exceptions_immediately=kwargs.get( - "log_exceptions_immediately", True + summarize_exceptions=kwargs.get( + "summarize_exceptions", False ), ) @@ -136,8 +136,9 @@ async def sample_trajectories(self, **kwargs): return await self._sample_trajectories_from_envs( kwargs["environments"], kwargs.get("max_steps"), - log_exceptions_immediately=kwargs.get( - "log_exceptions_immediately", True + summarize_exceptions=kwargs.get( + "summarize_exceptions", + False, ), ) @@ -152,7 +153,7 @@ async def _sample_trajectories_from_env_factory( batch_size: int = 1, max_steps: int | None = None, *, - log_exceptions_immediately: bool = True, + summarize_exceptions: bool = False, ) -> list[tuple[Trajectory, Environment]]: self.traj_buffer.clear() exception_counter: Counter = Counter() @@ -171,7 +172,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): traj_id=uuid.uuid4().hex, env=environment_factory(), max_steps=max_steps, - log_exceptions_immediately=log_exceptions_immediately, + summarize_exceptions=summarize_exceptions, ) ) for idx in range(batch_size) @@ -182,7 +183,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): desc="Rollouts", unit="rollout", ncols=0, - disable=log_exceptions_immediately, + disable=not summarize_exceptions, ) as pbar: while tasks: done, pending = await asyncio.wait( @@ -218,7 +219,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): traj_id=uuid.uuid4().hex, env=environment_factory(), max_steps=remaining_steps, - log_exceptions_immediately=log_exceptions_immediately, + summarize_exceptions=summarize_exceptions, ) ) new_tasks.append(new_task) @@ -226,7 +227,7 @@ async def rollout_with_args(idx: int, **rollout_kwargs): tasks = list(pending) + new_tasks # Final summary of exceptions (if any) - if exception_counter and not log_exceptions_immediately: + if exception_counter and summarize_exceptions: summary = ["Caught exceptions:", "Count Exception"] summary.extend( f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() @@ -240,7 +241,7 @@ async def _sample_trajectories_from_envs( environments: Sequence[Environment], max_steps: int | None = None, *, - log_exceptions_immediately: bool = True, + summarize_exceptions: bool = False, ) -> list[Trajectory]: self.traj_buffer.clear() exception_counter: Counter = Counter() @@ -254,7 +255,7 @@ async def _sample_trajectories_from_envs( traj_id, env, max_steps=max_steps, - log_exceptions_immediately=log_exceptions_immediately, + summarize_exceptions=summarize_exceptions, ) ) for traj_id, env in zip(traj_ids, environments, strict=True) @@ -265,7 +266,7 @@ async def _sample_trajectories_from_envs( desc="Rollouts", unit="rollout", ncols=0, - disable=log_exceptions_immediately, + disable=not summarize_exceptions, ) as pbar: for task in asyncio.as_completed(tasks): trajectory = await task @@ -283,7 +284,7 @@ async def _sample_trajectories_from_envs( pbar.set_postfix({"num_exceptions": num_exceptions}) # Final summary of exceptions (if any) - if exception_counter and not log_exceptions_immediately: + if exception_counter and summarize_exceptions: summary = ["Caught exceptions:", "Count Exception"] summary.extend( f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() @@ -298,7 +299,7 @@ async def _rollout( env: Environment, max_steps: int | None, *, - log_exceptions_immediately: bool = True, + summarize_exceptions: bool = False, ) -> Trajectory: trajectory = Trajectory(traj_id=traj_id) @@ -351,7 +352,7 @@ async def store_step(step: Transition): except CaughtError as e: # NOTE: This trajectory should not be used for regular training. # We save the last transition here for debugging, etc. - if log_exceptions_immediately: + if not summarize_exceptions: error_details = format_error_details(e.original_exc) logger.exception(f"Exception in rollout {traj_id}:\n{error_details}") diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index f9f76626..2ef3f4a9 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -20,9 +20,6 @@ "Please run `pip install ldp[nn]`." ) from None - -logger = logging.getLogger(__name__) - _TORCH_LOCK = asyncio.Lock() # Supported devices here: https://pytorch.org/docs/stable/amp.html#torch.autocast diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 82d5f5ec..1383a51b 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -122,8 +122,6 @@ async def get_asv( # Update state messages with result and return the new state next_state.messages = [*next_state.messages, result.value] - self._validate_token_count(next_state.messages, next_state.tools) - return cast("OpResult[ToolRequestMessage]", result), next_state, 0.0 def _validate_token_count(self, messages: list[Message], tools: list[Tool]): From 5651e54ca119a823481ef546e848025cf87c6eef Mon Sep 17 00:00:00 2001 From: kwanUm Date: Mon, 31 Mar 2025 12:45:34 +0300 Subject: [PATCH 23/34] Update src/ldp/nn/agent/simple_local_agent.py Co-authored-by: Siddharth Narayanan --- src/ldp/nn/agent/simple_local_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 1383a51b..858ba768 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -48,7 +48,7 @@ class AgentLMConfig(_LMConfig): ) max_messages_token_count: int | None = Field( default=None, - description="If set, raise an error if the total tokens in the trajectory exceed this value.", + description="If set, raise an error if the total tokens in the message history and tool description exceed this value.", ) @field_validator("llm_call_kwargs") From eb245199632637aac4da0d74ef9774e17c422fed Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 31 Mar 2025 04:48:33 -0500 Subject: [PATCH 24/34] nits --- src/ldp/nn/agent/simple_local_agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 858ba768..86f17332 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -46,7 +46,7 @@ class AgentLMConfig(_LMConfig): ), validate_default=True, ) - max_messages_token_count: int | None = Field( + max_token_count: int | None = Field( default=None, description="If set, raise an error if the total tokens in the message history and tool description exceed this value.", ) @@ -126,7 +126,7 @@ async def get_asv( def _validate_token_count(self, messages: list[Message], tools: list[Tool]): """Asserts token count for the trajectory is within the limit.""" - if self.llm_model.max_messages_token_count is None: + if self.llm_model.max_token_count is None: return messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages) tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer(tools) @@ -136,12 +136,12 @@ def _validate_token_count(self, messages: list[Message], tools: list[Tool]): messages=messages_for_tokenizer, tools=tools_for_tokenizer, # type: ignore[arg-type] ) - if total_tokens > self.llm_model.max_messages_token_count: + if total_tokens > self.llm_model.max_token_count: logger.error( - f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}" + f"Token limit exceeded: {total_tokens} > {self.llm_model.max_token_count}" ) raise ValueError( - f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}" + f"Token limit exceeded: {total_tokens} > {self.llm_model.max_token_count}" ) # TODO: maybe remove these recomputation methods. I added them to debug some things. But idk, From f7a8247b0b3c933cbbbc4d4f0413321f5c8db3ab Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 31 Mar 2025 04:51:13 -0500 Subject: [PATCH 25/34] nits --- src/ldp/alg/rollout.py | 4 +--- src/ldp/graph/async_torch.py | 1 - src/ldp/nn/agent/simple_local_agent.py | 5 ++++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py index 105dc173..385ba5c6 100644 --- a/src/ldp/alg/rollout.py +++ b/src/ldp/alg/rollout.py @@ -124,9 +124,7 @@ async def sample_trajectories(self, **kwargs): kwargs["environment_factory"], kwargs.get("batch_size", 1), kwargs.get("max_steps"), - summarize_exceptions=kwargs.get( - "summarize_exceptions", False - ), + summarize_exceptions=kwargs.get("summarize_exceptions", False), ) if "environments" in kwargs: diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index 2ef3f4a9..ef8fa51a 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -1,7 +1,6 @@ __all__ = ["AsyncTorchModule", "async_protect_torch_call"] import asyncio -import logging import operator import time from abc import ABC, abstractmethod diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 86f17332..bf173b2a 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -48,7 +48,10 @@ class AgentLMConfig(_LMConfig): ) max_token_count: int | None = Field( default=None, - description="If set, raise an error if the total tokens in the message history and tool description exceed this value.", + description=( + "If set, raise an error if the total tokens in the message history " + "and tool description exceed this value." + ), ) @field_validator("llm_call_kwargs") From decced05a9a78026178afa8a55c494ddbd486478 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 7 Apr 2025 09:50:05 -0500 Subject: [PATCH 26/34] fsdp2 support + slurm support --- src/ldp/nn/graph/llm_call_op.py | 2 - src/ldp/nn/handlers/README_FSDP2.md | 23 + src/ldp/nn/handlers/transformer_handler.py | 158 +++- .../nn/handlers/transformer_handler_fsdp2.py | 707 ++++++++++++++++++ tests/test_nn_models.py | 4 - 5 files changed, 860 insertions(+), 34 deletions(-) create mode 100644 src/ldp/nn/handlers/README_FSDP2.md create mode 100644 src/ldp/nn/handlers/transformer_handler_fsdp2.py diff --git a/src/ldp/nn/graph/llm_call_op.py b/src/ldp/nn/graph/llm_call_op.py index cf68a5f6..35e82ef5 100644 --- a/src/ldp/nn/graph/llm_call_op.py +++ b/src/ldp/nn/graph/llm_call_op.py @@ -12,7 +12,6 @@ from ldp.graph.op_utils import CallID, get_call_id, get_training_mode from ldp.graph.ops import GradInType, Op, OpCtx, ResultOrValue from ldp.nn.handlers.transformer_handler import ( - AsyncTransformerInterface, LMType, ParallelModeConfig, TransformerHandlerConfig, @@ -53,7 +52,6 @@ def __init__( parallel_mode_config=parallel_mode_config, # constant configuration lm_type=LMType.GENERATION, - module_call_fn=AsyncTransformerInterface.model_generate, collate_fn=partial( collate_fn_transformer_left_pad, pad_token_id=pad_token_id ), diff --git a/src/ldp/nn/handlers/README_FSDP2.md b/src/ldp/nn/handlers/README_FSDP2.md new file mode 100644 index 00000000..9bfea2ea --- /dev/null +++ b/src/ldp/nn/handlers/README_FSDP2.md @@ -0,0 +1,23 @@ +# FSDP2 Implementation for Transformer Handler + +This implementation replaces the Accelerate-based FSDP implementation with PyTorch's native FSDP2 API. + +## Key Changes + +1. **Direct use of FSDP2 APIs**: + + - Uses `fully_shard()` from `torch.distributed.fsdp.fully_shard` instead of Accelerate's wrapper + - Registers model methods with `register_fsdp_forward_method` to ensure proper handling of model.generate() + +2. **Simplified Configuration**: + + - Uses native FSDP2 policies such as `MixedPrecisionPolicy` and `OffloadPolicy` + - Removed dependency on Accelerate-specific config formats + +3. **State Dict Management**: + + - With FSDP2, state dicts contain DTensors, which can be converted to full tensors when needed + - Added utility to consolidate DTensor state dicts for checkpointing + +4. **Code Reuse**: + - Imports utility functions when possible from the original `transformer_handler.py` file diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 42a74637..ac600352 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -6,12 +6,13 @@ import os import socket import sys +import time from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from enum import StrEnum, auto from functools import cache, partial, wraps from pathlib import Path -from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never +from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, cast import accelerate import torch @@ -107,6 +108,11 @@ class FSDPConfig(BaseModel): " is PRE to be consistent with FSDP's default." ), ) + # FSDP2 specific settings + reshard_after_forward: bool = Field( + default=True, + description="Whether to free the full parameters after forward computation.", + ) @field_validator("backward_prefetch", mode="before") @classmethod @@ -131,7 +137,10 @@ class ParallelModeConfig(FSDPConfig): ), default=ExecutionMode.LOCAL_MACHINE, ) - + num_gpus_per_node: int = Field( + default=8, + description="Number of GPUs per node. Defaults to 8 for standard GPU nodes.", + ) scheduler_addr: str = "localhost" scheduler_port: int = Field(default=0, description="0 means Dask picks randomly.") torch_port: int = Field(default_factory=get_unused_port) @@ -140,7 +149,9 @@ class ParallelModeConfig(FSDPConfig): walltime: str = Field( default="00:30:00", description="Max time the worker can run." ) - memory: str = Field(default="32GB", description="Memory allocated per worker.") + memory_per_worker: str = Field( + default="32GB", description="Memory allocated per worker." + ) log_directory: str = Field( default=f"{REPO_ROOT}/logs/slurm_outputs/", description="Directory to store logs.", @@ -152,16 +163,25 @@ class LMType(StrEnum): REGRESSION = auto() +class TransformerImplementation(StrEnum): + ACCELERATOR = auto() # Current implementation using Accelerator + FSDP2 = auto() # New implementation using FSDP2 + + class TransformerHandlerConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") lm_config: LMConfig lm_type: LMType checkpoint: str | None = None + implementation: TransformerImplementation = Field( + default=TransformerImplementation.ACCELERATOR, + description="Which transformer implementation to use (Accelerator or FSDP2)", + ) batch_size: int max_wait_interval: float = 0.1 - module_call_fn: Callable + module_call_fn: Callable | None = None collate_fn: Callable decollate_fn: Callable @@ -172,10 +192,37 @@ class TransformerHandlerConfig(BaseModel): "multiple devices/nodes. If not provided, will default to single-device." ), ) + + @field_validator("module_call_fn", mode="before") + def set_default_module_call_fn(cls, v, info): + """Set default module_call_fn based on implementation if not provided.""" + if v is None: + # Get the implementation value from the validation context + implementation = info.data.get("implementation", TransformerImplementation.ACCELERATOR) + + # For now, both implementations use the same model_generate function + # but this allows for future differentiation based on implementation + if implementation == TransformerImplementation.FSDP2: + from .transformer_handler_fsdp2 import AsyncTransformerInterface as FSDP2AsyncTransformerInterface + return FSDP2AsyncTransformerInterface.model_generate + else: # Default or ACCELERATOR + return AsyncTransformerInterface.model_generate + return v def make_async_module(self, **kwargs) -> AsyncTransformerInterface: if self.parallel_mode_config: - return ParallelAsyncTransformer(config=self, **kwargs) + if self.implementation == TransformerImplementation.ACCELERATOR: + return ParallelAsyncTransformer(config=self, **kwargs) + if self.implementation == TransformerImplementation.FSDP2: + from .transformer_handler_fsdp2 import ( + ParallelTransformerHandler as FSDP2ParallelTransformerHandler, + ) + + return cast( + "ParallelAsyncTransformer", + FSDP2ParallelTransformerHandler(config=self, **kwargs), + ) + raise ValueError(f"Unsupported implementation: {self.implementation}") return AsyncTransformer(config=self, **kwargs) @@ -255,11 +302,16 @@ def local_rank(self) -> int: def load_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: logger.info(f'Loading checkpoint from "{ckpt}"') + start_time = time.perf_counter() self.accelerator.load_state(str(ckpt), **kwargs) + self.barrier() + logger.info(f"Loading checkpoint took {time.perf_counter() - start_time:.2f}s") def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + start_time = time.perf_counter() self.accelerator.save_state(str(ckpt), **kwargs) self.barrier() + logger.info(f"Saving checkpoint took {time.perf_counter() - start_time:.2f}s") # We do not want to save random states - they would be loaded by load_state # automatically. Clean up after all processes have saved. if int(os.getenv("RANK", "0")) == 0: @@ -290,7 +342,6 @@ def __init__(self, config: TransformerHandlerConfig): max_wait_interval=config.max_wait_interval, collate_fn=config.collate_fn, decollate_fn=config.decollate_fn, - module_call_fn=config.module_call_fn, ) async def __call__( @@ -336,7 +387,6 @@ class ParallelWorkerConfig(FSDPConfig): def set_env_vars(self): # These inform torch.distributed how to set up the process group - os.environ["CUDA_VISIBLE_DEVICES"] = str(self.local_rank) os.environ["RANK"] = str(self.rank) os.environ["WORLD_SIZE"] = str(self.world_size) os.environ["LOCAL_RANK"] = str(self.local_rank) @@ -348,6 +398,7 @@ def set_env_vars(self): os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str( int(self.cpu_ram_efficient_loading) ) + os.environ["ACCELERATE_TORCH_DEVICE"] = f"cuda:{self.local_rank}" class ParallelTransformerHandler(TransformerHandler): @@ -357,7 +408,9 @@ def __init__( parallel_worker_config: ParallelWorkerConfig, ): parallel_worker_config.set_env_vars() - dist.init_process_group(backend="nccl") + dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}")) + torch.cuda.set_device(self.local_rank) + dist.barrier() self.worker_config = parallel_worker_config super().__init__(config) @@ -368,7 +421,7 @@ def _setup_accelerator(self): if bf16: bf16_ready = ( torch.version.cuda - and torch.cuda.is_bf16_supported() + # and torch.cuda.is_bf16_supported() # TODO add it back and torch.version.cuda >= "11.0" and dist.is_nccl_available() and nccl.version() >= (2, 10) @@ -413,6 +466,7 @@ def _exec_func( *args, **kwargs, ) -> TReturn: + torch.cuda.set_device(self.local_rank) # data will be on CPU when sent from controller data_device = _get_data_device() to_device = partial(_move_tensor, device=data_device) @@ -495,43 +549,94 @@ def _init_local_cluster( # lazy import since dask-cuda only works on Linux machines from dask_cuda import LocalCUDACluster - # This uses NVIDIA's NVML layer instead of native CUDA, which is more robust in GPU detection - # post initialization. This prevents issues with forked processes wrongly detecting the - # default GPU as cuda:0 - os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" + kwargs = {} + if os.environ.get("USE_UCX"): + kwargs = { + "protocol": "ucx", + "enable_tcp_over_ucx": True, + "enable_infiniband": True, + "enable_nvlink": True, + } + self.cluster = LocalCUDACluster( n_workers=parallel_mode_config.num_workers, threads_per_worker=parallel_mode_config.num_cpus_per_worker, host=parallel_mode_config.scheduler_addr, port=parallel_mode_config.scheduler_port, memory_limit=None, # do not let Dask manage memory - if we OOM, we OOM + device_memory_limit=0, # Disable gpu memory spilling. Should be handled by FSDP + **kwargs, ) + self.cluster.scale(parallel_mode_config.num_workers) self._initialize_workers(config, parallel_mode_config) def _init_slurm_cluster( self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig ): - """Initialize a SLURM-based Dask cluster with GPU allocation.""" + """Initialize a SLURM-based Dask cluster with GPU allocation. + + Note: Dask's integration with SLURM currently only supports allocating single entire node + at a time, with each node running as a single SLURM task. This implementation adapts + to that limitation by requesting complete nodes and running multiple workers (one per GPU) + within each node. If our cluster eventually supports GRES (Generic Resource) scheduling, + this implementation could be modified to allow for more granular GPU allocation across + nodes rather than requiring full node allocation (I think, needs to be tested). + """ # Lazy import because dask_jobqueue cannot be started in a subprocess, which # happens e.g. with streamlit - from dask_jobqueue import SLURMCluster + from dask_jobqueue.slurm import SLURMCluster + + # Validate that num_workers is divisible by num_gpus_per_node + num_gpus_per_node = parallel_mode_config.num_gpus_per_node + if parallel_mode_config.num_workers % num_gpus_per_node != 0: + raise ValueError( + f"Number of workers ({parallel_mode_config.num_workers}) must be divisible by " + f"num_gpus_per_node ({num_gpus_per_node}). We assume each node has {num_gpus_per_node} GPUs, " + f"and current dask-jobqueue infrastructure only supports allocating whole nodes. " + ) + # TODO: add support for gres when available in our cluster for partial node allocation + + # Calculate number of jobs needed (each job = 1 slurm node with num_gpus_per_node GPUs) + num_jobs = parallel_mode_config.num_workers // num_gpus_per_node + + log_dir = parallel_mode_config.log_directory + os.makedirs(log_dir, exist_ok=True) + + # Calculate total memory needed per node (memory_per_worker * num_gpus_per_node) + memory_per_worker = parallel_mode_config.memory_per_worker + MEMORY_UNIT_LENGTH = 2 # Memory units are typically 2 chars (e.g. "GB", "MB") + value = int( + memory_per_worker[:-MEMORY_UNIT_LENGTH] + ) # Get numeric value by removing last 2 chars (e.g. "GB") + unit = memory_per_worker[-MEMORY_UNIT_LENGTH:] # Get unit (e.g. "GB") + assert len(unit) == MEMORY_UNIT_LENGTH, ( + f"Memory unit must be {MEMORY_UNIT_LENGTH} characters long, got {unit}" + ) + total_memory = f"{value * parallel_mode_config.num_gpus_per_node}{unit}" self.cluster = SLURMCluster( - cores=parallel_mode_config.num_cpus_per_worker, - memory=parallel_mode_config.memory, - processes=1, # Single dask worker per slurm worker + cores=parallel_mode_config.num_cpus_per_worker * num_gpus_per_node, + memory=total_memory, + processes=num_gpus_per_node, # Each job runs num_gpus_per_node dask workers (one per GPU) walltime=parallel_mode_config.walltime, - job_extra=[ - "--gres=gpu:1" - ], # 1 GPU per worker seems to be the common case for now - log_directory=parallel_mode_config.log_directory, + job_extra_directives=[ + "--nodes=1", # Always request 1 node per job + "--exclusive", # Exclusive node access + "--mem=0", # Use all available memory + f"--cpus-per-task={parallel_mode_config.num_cpus_per_worker}", + f"-o {log_dir}/job_%j_task_%t.out", + f"-e {log_dir}/job_%j_task_%t.err", + ], + log_directory=log_dir, ) + + # Scale jobs to the required number of jobs + self.cluster.scale(jobs=num_jobs) self._initialize_workers(config, parallel_mode_config) def _initialize_workers( self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig ): - self.cluster.scale(parallel_mode_config.num_workers) self.client = Client(self.cluster) self.client.wait_for_workers(parallel_mode_config.num_workers) @@ -542,7 +647,6 @@ def get_cuda_visible_devices() -> int | None: if "," in device: device = device.split(",", maxsplit=1)[0] os.environ["CUDA_VISIBLE_DEVICES"] = device - os.environ["CUDA_VISIBLE_DEVICES"] = device return int(device) return None @@ -558,12 +662,10 @@ def get_cuda_visible_devices() -> int | None: worker_ids = [] for rank, (worker_address, worker_data) in enumerate(sorted_workers.items()): worker_id = worker_data["id"] + # On some occasions, dask SLURM integration auto assigns CUDA_VISIBLE_DEVICES, otherwise we set it here worker_cuda_device = worker_to_cuda_device[worker_address] if worker_cuda_device is None: - assert ( - parallel_mode_config.execution_mode != ExecutionMode.SLURM_CLUSTER - ), "CUDA_VISIBLE_DEVICES should be pre set for SLURM workers." - worker_cuda_device = rank + worker_cuda_device = rank % parallel_mode_config.num_gpus_per_node parallel_worker_config = ParallelWorkerConfig( rank=rank, diff --git a/src/ldp/nn/handlers/transformer_handler_fsdp2.py b/src/ldp/nn/handlers/transformer_handler_fsdp2.py new file mode 100644 index 00000000..e29ca7e0 --- /dev/null +++ b/src/ldp/nn/handlers/transformer_handler_fsdp2.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +import asyncio +import atexit +import logging +import os +import time +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from functools import partial, wraps +from pathlib import Path +from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, overload + +import torch +import torch.distributed as dist +import tree +from dask import config +from dask.distributed import Actor, ActorFuture, Client +from distributed.utils import sync + +try: + assert torch.__version__ >= "2.6.0", "FSDP2 requires PyTorch 2.6.0 or higher" + from torch.distributed.fsdp import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + OffloadPolicy, + fully_shard, + register_fsdp_forward_method, + ) +except (ImportError, AssertionError) as e: + raise ImportError(f"FSDP2 requires PyTorch 2.6.0 or higher: {e}") from e +from transformers import PreTrainedModel +from transformers.generation.utils import GenerateDecoderOnlyOutput +from transformers.tokenization_utils_base import BatchEncoding + +from ldp.graph.async_torch import AsyncBufferedWorker, AsyncTorchModule +from ldp.nn.handlers.chunking import TensorChunker +from ldp.nn.handlers.module_handler import ModuleExecutionInterface, ModuleHandler +from ldp.nn.lm_config import TorchDType +from ldp.nn.utils import set_seed + +from .transformer_handler import ( + AsyncTransformer, + ExecutionMode, + FSDPConfig, + LMType, + ParallelModeConfig, + ParallelWorkerConfig, + TransformerHandlerConfig, + _get_data_device, + _get_tokenized_inputs, + _move_tensor, + _process_outputs, + maybe_set_tokenizer_chat_template, +) + +logger = logging.getLogger(__name__) + +config.set({ + # We have no use for rebooting workers in aviary for now, and rebooting workers + # is annoying when debugging. + "distributed.scheduler.allowed-failures": 0, + # FSDP forward/backward passes can take way longer than the default warning at 3s + "distributed.admin.tick.limit": "30s", + # Gives us more time to debug a downed worker. TODO: see if there are negative consequences + # of having this always enabled + "distributed.comm.timeouts.connect": "300s", + "distributed.comm.timeouts.tcp": "1200s", +}) + +TReturn = TypeVar("TReturn") +TParams = ParamSpec("TParams") + + +class AsyncTransformerInterface(ModuleExecutionInterface, AsyncTorchModule, ABC): + """Base class for async interactions with a transformer model.""" + + @abstractmethod + async def __call__( # type: ignore[override] + self, + inputs: str | BatchEncoding | list[dict], + tools_json: list[dict] | None = None, + **kwargs, + ) -> tuple[str, torch.Tensor]: + """Call the transformer on a single input, which may be encoded.""" + + @staticmethod + def model_generate(model: PreTrainedModel, *args, **kwargs): + """A method that can be used as module_call_fn to sample from an LLM.""" + logger.debug( + f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" + f" {os.environ.get('RANK')}" + ) + return model.generate( + *args, + **kwargs, + pad_token_id=model.config.pad_token_id, # not always set properly by .generate() + eos_token_id=model.config.eos_token_id, + ) + + +class TransformerHandler(ModuleHandler): + def __init__(self, config: TransformerHandlerConfig): + # Maybe this should be configurable? Hard to isolate the effect though + torch.set_float32_matmul_precision("high") + + self.config = config + # Use local_rank to resolve model location only in the main process *for each node* + config.lm_config.resolve_model_location(is_main_process=self.local_rank == 0) + + match config.lm_type: + case LMType.GENERATION: + tokenizer, model = config.lm_config.get_causal_lm() + # On left for https://github.com/huggingface/transformers/pull/7552 + # ^ that seems to work for most HF models w/ absolute position embeddings + # Left padding always works for relative position embeddings + tokenizer.padding_side = "left" + case LMType.REGRESSION: + tokenizer, model = config.lm_config.get_regression_lm() + case _: + assert_never(config.lm_type) + super().__init__(model) + self.tokenizer = tokenizer + maybe_set_tokenizer_chat_template( + self.tokenizer, self.config.lm_config.chat_template + ) + + self._setup_fsdp() + + if config.checkpoint is not None: + self.load_checkpoint(config.checkpoint) + + def _setup_fsdp(self): + """Set up FSDP2 module wrapping.""" + if not dist.is_initialized(): + # For single device usage, just move to device directly + device = self.config.lm_config.device + self.module = self.module.to(device) + return + + # Setup mixed precision policy if needed + mp_policy = None + if self.config.lm_config.dtype == TorchDType.bf16: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + output_dtype=torch.bfloat16, + ) + + # Apply FSDP2 wrapping using new API + fsdp_config = self.config.parallel_mode_config or FSDPConfig() + offload_policy = ( + CPUOffloadPolicy(pin_memory=True) + if fsdp_config.offload_cpu + else OffloadPolicy() + ) + + self.module = fully_shard( + self.module, + mesh=None, # Maybe we activate it later, see https://pytorch.org/docs/stable/distributed.html#torch.distributed.device_mesh.DeviceMesh + reshard_after_forward=fsdp_config.reshard_after_forward, + mp_policy=mp_policy, + offload_policy=offload_policy, + ) + + # Register model.generate as an FSDP forward method to handle generation correctly + register_fsdp_forward_method(self.module, "generate") + + @property + def local_rank(self) -> int: + return int(os.getenv("LOCAL_RANK", "0")) + + def load_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + logger.info(f'Loading checkpoint from "{ckpt}"') + start_time = time.perf_counter() + ckpt_path = Path(ckpt) + if ckpt_path.is_dir(): + # Assume it's a directory containing sharded state dict + state_dict = torch.load( + ckpt_path / f"rank{self.local_rank}_checkpoint.pt", map_location="cpu" + ) + self.module.load_state_dict(state_dict) + else: + # Assume it's a single file containing a full state dict + state_dict = torch.load(ckpt, map_location="cpu") + # Load the state dict - will automatically handle the sharding + self.module.load_state_dict(state_dict) + + self.barrier() + logger.info(f"Loading checkpoint took {time.perf_counter() - start_time:.2f}s") + torch.cuda.empty_cache() + + def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + start_time = time.perf_counter() + ckpt_path = Path(ckpt) + ckpt_path.mkdir(parents=True, exist_ok=True) + state_dict = self.module.state_dict() + if dist.is_initialized(): + torch.save(state_dict, ckpt_path / f"rank{self.local_rank}_checkpoint.pt") + else: + torch.save(state_dict, ckpt_path / "checkpoint.pt") + + self.barrier() + logger.info(f"Saving checkpoint took {time.perf_counter() - start_time:.2f}s") + + @staticmethod + def barrier() -> None: + if dist.is_initialized(): + dist.barrier() + + +class ParallelTransformerHandler(TransformerHandler): + def __init__( + self, + config: TransformerHandlerConfig, + parallel_worker_config: ParallelWorkerConfig, + ): + parallel_worker_config.set_env_vars() + dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}")) + dist.barrier() + torch.cuda.set_device(self.local_rank) + dist.barrier() + self.worker_config = parallel_worker_config + super().__init__(config) + + def set_seed(self, seed: int) -> None: + """Set the seed for the current worker.""" + set_seed(seed) + + def _exec_func( + self, + func: Callable[Concatenate[Self, TParams], TReturn] | str, + *args, + **kwargs, + ) -> TReturn: + # data will be on CPU when sent from controller + data_device = _get_data_device() + to_device = partial(_move_tensor, device=data_device) + args = tree.map_structure(to_device, args) + kwargs = tree.map_structure(to_device, kwargs) + + try: + with torch.autocast( + device_type=self.module.device.type, + dtype=torch.bfloat16 + if self.config.lm_config.dtype == TorchDType.bf16 + else torch.float32, + ): + res = ( + getattr(self, func)(*args, **kwargs) + if isinstance(func, str) + else func(self, *args, **kwargs) + ) + + # Needed to prevent GPU memory leak to the main process scheduling the workers + if isinstance(res, GenerateDecoderOnlyOutput): + res.past_key_values = None + res["past_key_values"] = None + + to_cpu = partial(_move_tensor, device=torch.device("cpu")) + return tree.map_structure(to_cpu, res) + except Exception as e: + # Re-raise the exception with traceback preserved. For some exceptions, Dask + # modifies or loses the original traceback when crossing process boundaries. + # RuntimeError preserves the traceback when using with_traceback() of original + # exception. + raise RuntimeError(str(e)).with_traceback(e.__traceback__) # noqa: B904 + + def __del__(self) -> None: + if dist.is_initialized(): + dist.destroy_process_group() + + +class ParallelAsyncTransformer(AsyncTransformerInterface): + def __init__(self, config: TransformerHandlerConfig): + self._initialized = False + + parallel_mode_config = config.parallel_mode_config + if not parallel_mode_config: + raise ValueError("Parallel mode config must be provided.") + self.config = config + self.tokenizer = config.lm_config.get_tokenizer() + maybe_set_tokenizer_chat_template( + self.tokenizer, self.config.lm_config.chat_template + ) + + match parallel_mode_config.execution_mode: + # TODO: see if we can just access `parallel_mode_config` as a + # `config` attribute instead of passing both. + case ExecutionMode.LOCAL_MACHINE: + self._init_local_cluster(config, parallel_mode_config) + case ExecutionMode.SLURM_CLUSTER: + self._init_slurm_cluster(config, parallel_mode_config) + case _: + assert_never(parallel_mode_config.execution_mode) + + self._initialized = True + + atexit.register(self.teardown) + + # don't call AsyncTorchModule.__init__ because we don't need to set up module[_call_fn] + AsyncBufferedWorker.__init__( + self, + batch_size=config.batch_size, + max_wait_interval=config.max_wait_interval, + collate_fn=config.collate_fn, + decollate_fn=config.decollate_fn, + ) + + def handler_call_fn(handler: ParallelTransformerHandler, *args, **kwargs): + return config.module_call_fn(handler.module, *args, **kwargs) + + self.handler_call_fn = handler_call_fn + + def _init_local_cluster( + self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig + ): + """Initialize a Dask cluster on local machine.""" + # lazy import since dask-cuda only works on Linux machines + from dask_cuda import LocalCUDACluster + + kwargs = {} + if os.environ.get("USE_UCX"): + kwargs = { + "protocol": "ucx", + "enable_tcp_over_ucx": True, + "enable_infiniband": True, + "enable_nvlink": True, + } + + self.cluster = LocalCUDACluster( + n_workers=parallel_mode_config.num_workers, + threads_per_worker=parallel_mode_config.num_cpus_per_worker, + host=parallel_mode_config.scheduler_addr, + port=parallel_mode_config.scheduler_port, + memory_limit=None, # do not let Dask manage memory - if we OOM, we OOM + device_memory_limit=0, # Disable gpu memory spilling. Should be handled by FSDP + **kwargs, + ) + self.cluster.scale(parallel_mode_config.num_workers) + self._initialize_workers(config, parallel_mode_config) + + def _init_slurm_cluster( + self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig + ): + """Initialize a SLURM-based Dask cluster with GPU allocation. + + Note: Dask's integration with SLURM currently only supports allocating single entire node + at a time, with each node running as a single SLURM task. This implementation adapts + to that limitation by requesting complete nodes and running multiple workers (one per GPU) + within each node. If our cluster eventually supports GRES (Generic Resource) scheduling, + this implementation could be modified to allow for more granular GPU allocation across + nodes rather than requiring full node allocation (I think, needs to be tested). + """ + # Lazy import because dask_jobqueue cannot be started in a subprocess, which + # happens e.g. with streamlit + from dask_jobqueue.slurm import SLURMCluster + + # Validate that num_workers is divisible by num_gpus_per_node + num_gpus_per_node = parallel_mode_config.num_gpus_per_node + if parallel_mode_config.num_workers % num_gpus_per_node != 0: + raise ValueError( + f"Number of workers ({parallel_mode_config.num_workers}) must be divisible by " + f"num_gpus_per_node ({num_gpus_per_node}). We assume each node has {num_gpus_per_node} GPUs, " + f"and current dask-jobqueue infrastructure only supports allocating whole nodes. " + ) + # TODO: add support for gres when available in our cluster for partial node allocation + + # Calculate number of jobs needed (each job = 1 slurm node with num_gpus_per_node GPUs) + num_jobs = parallel_mode_config.num_workers // num_gpus_per_node + + log_dir = parallel_mode_config.log_directory + os.makedirs(log_dir, exist_ok=True) + + # Calculate total memory needed per node (memory_per_worker * num_gpus_per_node) + memory_per_worker = parallel_mode_config.memory_per_worker + MEMORY_UNIT_LENGTH = 2 # Memory units are typically 2 chars (e.g. "GB", "MB") + value = int( + memory_per_worker[:-MEMORY_UNIT_LENGTH] + ) # Get numeric value by removing last 2 chars (e.g. "GB") + unit = memory_per_worker[-MEMORY_UNIT_LENGTH:] # Get unit (e.g. "GB") + assert len(unit) == MEMORY_UNIT_LENGTH, ( + f"Memory unit must be {MEMORY_UNIT_LENGTH} characters long, got {unit}" + ) + total_memory = f"{value * parallel_mode_config.num_gpus_per_node}{unit}" + + self.cluster = SLURMCluster( + cores=parallel_mode_config.num_cpus_per_worker * num_gpus_per_node, + memory=total_memory, + processes=num_gpus_per_node, # Each job runs num_gpus_per_node dask workers (one per GPU) + walltime=parallel_mode_config.walltime, + job_extra_directives=[ + "--nodes=1", # Always request 1 node per job + "--exclusive", # Exclusive node access + "--mem=0", # Use all available memory + f"--cpus-per-task={parallel_mode_config.num_cpus_per_worker}", + f"-o {log_dir}/job_%j_task_%t.out", + f"-e {log_dir}/job_%j_task_%t.err", + ], + log_directory=log_dir, + ) + + # Scale jobs to the required number of jobs + self.cluster.scale(jobs=num_jobs) + self._initialize_workers(config, parallel_mode_config) + + def _initialize_workers( + self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig + ): + self.client = Client(self.cluster) + self.client.wait_for_workers(parallel_mode_config.num_workers) + + def get_cuda_visible_devices() -> int | None: + device = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if device is not None: + # If has several devices, assume the first one is the one to use for that worker + if "," in device: + device = device.split(",", maxsplit=1)[0] + os.environ["CUDA_VISIBLE_DEVICES"] = device + return int(device) + return None + + worker_to_cuda_device = self.client.run(get_cuda_visible_devices) + workers_info = self.client.scheduler_info()["workers"] + sorted_workers = dict( + sorted(workers_info.items(), key=lambda item: item[1]["id"]) + ) + # The first worker is the master in the torch distributed setup + master_addr = next(iter(sorted_workers.values()))["host"] + + futures = [] + worker_ids = [] + for rank, (worker_address, worker_data) in enumerate(sorted_workers.items()): + worker_id = worker_data["id"] + # On some occasions, dask SLURM integration auto assigns CUDA_VISIBLE_DEVICES, otherwise we set it here + worker_cuda_device = worker_to_cuda_device[worker_address] + if worker_cuda_device is None: + worker_cuda_device = rank % parallel_mode_config.num_gpus_per_node + + parallel_worker_config = ParallelWorkerConfig( + rank=rank, + world_size=parallel_mode_config.num_workers, + local_rank=worker_cuda_device, + master_addr=master_addr, + master_port=parallel_mode_config.torch_port, + **parallel_mode_config.model_dump(), + ) + future_op = self.client.submit( + ParallelTransformerHandler, + config=config, + parallel_worker_config=parallel_worker_config, + workers=[worker_id], + actor=True, + ) + futures.append(future_op) + worker_ids.append(worker_id) + + self.actors: list[Actor] = self._client_gather(futures) + self.worker_ids = worker_ids + + async def __call__( + self, + inputs: str | BatchEncoding | list[dict] | None = None, + tools_json: list[dict] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> tuple[str, torch.Tensor]: + if inputs is None: + inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + inputs_tokenized = _get_tokenized_inputs(self.tokenizer, inputs, tools_json) + inputs_len = inputs_tokenized["input_ids"].shape[1] + outputs = await AsyncBufferedWorker.__call__(self, **inputs_tokenized, **kwargs) + AsyncTransformer._maybe_finalize_logits_processors( + kwargs.get("logits_processor"), outputs + ) + + return _process_outputs(self.config, self.tokenizer, outputs, inputs_len) + + async def _batched_call(self, batch_kwargs: dict[str, Any]): + return self._submit_and_gather( + self.handler_call_fn, **batch_kwargs, split_data=True + ) + + def _submit_and_gather( + self, + func: Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn] | str, + *args, + split_data: bool = False, + **kwargs, + ) -> list[TReturn]: + """Submit a function to all workers and gather the results. + + Args: + func: The function to send to each worker. If a string is provided, + then getattr(handler, func) is used. If func is not a string, + the first argument must be the ParallelTransformerHandler that it will + be executed on. + split_data: If True, split the data between workers. If False, + send the same data to all workers. + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + The gathered results from the workers. + """ + if split_data: + chunker = TensorChunker( + num_chunks=len(self.actors), + ) + split_args, split_kwargs, dummy_flags = chunker.chunkify(*args, **kwargs) + else: + split_args = [args] * len(self.actors) + split_kwargs = [kwargs] * len(self.actors) + + futures = [ + handler._exec_func( + func, + *args_i, + **kwargs_i, + ) + for handler, worker_id, args_i, kwargs_i in zip( + self.actors, self.worker_ids, split_args, split_kwargs, strict=True + ) + ] + results: list[TReturn] = self._client_gather(futures) + + if split_data: + return chunker.dechunkify(results, dummy_flags) + return results + + def wrap_afunc( + self, + func: Callable[ + Concatenate[ParallelTransformerHandler, TParams], Awaitable[TReturn] + ], + **kwargs, + ) -> Callable[TParams, Awaitable[TReturn]]: + raise NotImplementedError( + "ParallelAsyncTransformer does not implement wrap_afunc(). " + "Wrap a synchronous function with wrap_func() instead." + ) + + @overload + def wrap_func( + self, + *, + worker_agg_fn: Callable[[list[TReturn]], TReturn] | None = None, + **kwargs, + ) -> Callable[ + [Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn]], + Callable[TParams, TReturn], + ]: ... + + @overload + def wrap_func( + self, + func: Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn], + *, + worker_agg_fn: Callable[[list[TReturn]], TReturn] | None = None, + **kwargs, + ) -> Callable[TParams, TReturn]: ... + + def wrap_func( + self, + func: ( + Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn] | None + ) = None, + *, + worker_agg_fn: Callable[[list[TReturn]], TReturn] | None = None, + **kwargs, + ) -> Callable: + """Wrap a function to execute on all workers and return gathered results. + + Args: + func: The function to wrap. + worker_agg_fn: A function to aggregate the results from all workers. + kwargs: Arguments that are discarded. Included here to enable a + subclass to add additional arguments. + """ + if worker_agg_fn is None: + raise ValueError("worker_agg_fn must be provided.") + + if func is None: + return partial(self.wrap_func, worker_agg_fn=worker_agg_fn, **kwargs) + + @wraps(func) + def wrapped_func(*args, **kwargs) -> TReturn: + return worker_agg_fn(self._submit_and_gather(func, *args, **kwargs)) + + return wrapped_func + + def state_dict(self, **kwargs) -> dict[str, torch.Tensor]: + """Get consolidated state dict from all workers. + + With FSDP2, we need to manually consolidate the state dict + """ + + def state_dict_worker( + handler: ParallelTransformerHandler, + ) -> dict[str, torch.Tensor]: + state_dict = handler.module.state_dict() + # Convert DTensors to full tensors + for key, tensor in state_dict.items(): + if hasattr(tensor, "full_tensor"): + state_dict[key] = tensor.full_tensor() + return state_dict + + # Only need the state dict from rank 0 + state_dict = self._submit_and_gather(state_dict_worker, **kwargs)[0] + return {k: v.cpu() for k, v in state_dict.items()} + + def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: + # For some reason, Dask hangs when we pass a large object (e.g. state_dict) + # directly to the workers. I can replicate it with the following: + # + # @handler.wrap_func + # def hello(handler, _): + # print("hello") + # + # hello([0] * 1_000_000) + # NOTE: this does not seem to be FSDP-related, as the issue didn't go away when + # I disabled FSDP. + raise NotImplementedError( + "ParallelAsyncTransformer.load_state_dict() is not implemented yet. It is" + " recommended to use .save_checkpoint() and .load_checkpoint() instead. " + ) + + def load_checkpoint(self, ckpt: os.PathLike | str) -> None: + self._submit_and_gather("load_checkpoint", ckpt) + + def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: + self._submit_and_gather("save_checkpoint", ckpt, **kwargs) + + def teardown(self) -> None: + if self._initialized: + self.client.shutdown() + self.cluster.close() + del self.client + del self.cluster + self._initialized = False + + def __del__(self) -> None: + self.teardown() + + @staticmethod + def _wrap_dask_future(dask_future: ActorFuture): + """Converts a Dask ActorFuture into an awaitable asyncio.Future.""" + loop = asyncio.get_running_loop() + return asyncio.ensure_future(loop.run_in_executor(None, dask_future.result)) + + @staticmethod + def _raise_exceptions(done, pending, wrapped_futures): + exceptions = [] + for future in done: + exc = future.exception() + if exc: + exceptions.append(exc) + if exceptions: + if len(exceptions) == 1: + raise exceptions[0] + raise ExceptionGroup("Multiple actor exceptions", exceptions) + + if pending: + pending_indices = sorted([wrapped_futures.index(p) for p in pending]) + raise TimeoutError( + f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} " + f"still pending. Pending task indices: {pending_indices}" + ) + + async def _client_gather_async(self, futures): + """Gather results from futures, propagating exceptions as they arrive. + + Unlike client.gather() which waits for all futures to complete before raising + any exceptions, this method processes futures as they complete and raises + exceptions immediately. This is crucial when using FSDP where workers may + be stuck waiting for each other when one worker crashes, causing long hangs. + + Note: Dask Actors currently have an issue where they're not working properly with + dask.gather() and can cause blocking issues or hide worker errors. This implementation + works around those limitations. + """ + try: + wrapped_futures = [self._wrap_dask_future(f) for f in futures] + + # Use asyncio.wait with FIRST_EXCEPTION instead of gather + done, pending = await asyncio.wait( + wrapped_futures, timeout=1200, return_when=asyncio.FIRST_EXCEPTION + ) + + self._raise_exceptions(done, pending, wrapped_futures) + + return await asyncio.gather(*wrapped_futures) + except Exception: + logger.exception("Error in dask workers: %s") + for future in wrapped_futures: + future.cancel() + self.teardown() + # sys.exit(1) would wait for dask to finish, which can cause hanging + # when workers are in a deadlock. Use os._exit to force immediate termination + # TODO: this is more of a hack, we should propagate special exception that is + # not caught by the rollout manager. + os._exit(1) + + def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: + # Use distributed.utils.sync to run the async function in the current thread + return sync(self.client.loop, self._client_gather_async, futures) # type: ignore[arg-type] diff --git a/tests/test_nn_models.py b/tests/test_nn_models.py index 55917ce3..8e129e93 100644 --- a/tests/test_nn_models.py +++ b/tests/test_nn_models.py @@ -155,7 +155,6 @@ async def test_generation( lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=4, - module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=tokenizer.pad_token_id, @@ -201,7 +200,6 @@ def test_state_dicts( lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=1, - module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=model_config.get_tokenizer().pad_token_id, @@ -242,7 +240,6 @@ def test_distributed_checkpoints(self, sharded: bool) -> None: lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=1, - module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=model_config.get_tokenizer().pad_token_id, @@ -332,7 +329,6 @@ def test_consistent_weights(self): lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=1, - module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=model_config.get_tokenizer().pad_token_id, From caf9edfc903fc69f631bec6b22b2f077fe7c7486 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 7 Apr 2025 09:50:41 -0500 Subject: [PATCH 27/34] SLURM + FSDP2 support --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 5293aadb..540b1d72 100644 --- a/.gitignore +++ b/.gitignore @@ -299,3 +299,6 @@ cython_debug/ **/version.py .vscode/ + +src/logs/ +slurm_logs/ From 94fba95ea4a7b06aa0cf1975a3e154716f82f275 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 7 Apr 2025 13:31:21 -0500 Subject: [PATCH 28/34] nit --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index 540b1d72..5b6f0b5f 100644 --- a/.gitignore +++ b/.gitignore @@ -300,5 +300,3 @@ cython_debug/ .vscode/ -src/logs/ -slurm_logs/ From c809dc3e39a7e0016225739a2203bfb1eba1c755 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 7 Apr 2025 13:41:08 -0500 Subject: [PATCH 29/34] nit --- src/ldp/nn/handlers/transformer_handler.py | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index ac600352..364e8025 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -181,7 +181,7 @@ class TransformerHandlerConfig(BaseModel): batch_size: int max_wait_interval: float = 0.1 - module_call_fn: Callable | None = None + _module_call_fn: Callable | None = None collate_fn: Callable decollate_fn: Callable @@ -193,21 +193,20 @@ class TransformerHandlerConfig(BaseModel): ), ) - @field_validator("module_call_fn", mode="before") - def set_default_module_call_fn(cls, v, info): - """Set default module_call_fn based on implementation if not provided.""" - if v is None: - # Get the implementation value from the validation context - implementation = info.data.get("implementation", TransformerImplementation.ACCELERATOR) - - # For now, both implementations use the same model_generate function - # but this allows for future differentiation based on implementation - if implementation == TransformerImplementation.FSDP2: + @property + def module_call_fn(self) -> Callable: + """Get the module call function based on implementation.""" + if self._module_call_fn is None: + if self.implementation == TransformerImplementation.FSDP2: from .transformer_handler_fsdp2 import AsyncTransformerInterface as FSDP2AsyncTransformerInterface return FSDP2AsyncTransformerInterface.model_generate else: # Default or ACCELERATOR return AsyncTransformerInterface.model_generate - return v + return self._module_call_fn + + @module_call_fn.setter + def module_call_fn(self, value: Callable | None) -> None: + self._module_call_fn = value def make_async_module(self, **kwargs) -> AsyncTransformerInterface: if self.parallel_mode_config: From 62a499f9a5400c5a098c438f2084150344ca2dd0 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Mon, 7 Apr 2025 13:44:31 -0500 Subject: [PATCH 30/34] nit --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 5b6f0b5f..0ab0d577 100644 --- a/.gitignore +++ b/.gitignore @@ -300,3 +300,6 @@ cython_debug/ .vscode/ +# Slurm log files +*.err +*.out From 0cdba00d5f9743fa0227b69445057bb163133a3f Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 23 Apr 2025 07:32:34 -0500 Subject: [PATCH 31/34] nits --- src/ldp/nn/handlers/transformer_handler.py | 18 +++++++++++------- .../nn/handlers/transformer_handler_fsdp2.py | 4 +++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 24ba5d6e..19041629 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -12,7 +12,7 @@ from enum import StrEnum, auto from functools import cache, partial, wraps from pathlib import Path -from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never +from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, cast import accelerate import torch @@ -192,16 +192,19 @@ class TransformerHandlerConfig(BaseModel): "multiple devices/nodes. If not provided, will default to single-device." ), ) - + @property def module_call_fn(self) -> Callable: """Get the module call function based on implementation.""" if self._module_call_fn is None: if self.implementation == TransformerImplementation.FSDP2: - from .transformer_handler_fsdp2 import AsyncTransformerInterface as FSDP2AsyncTransformerInterface + from .transformer_handler_fsdp2 import ( + AsyncTransformerInterface as FSDP2AsyncTransformerInterface, + ) + return FSDP2AsyncTransformerInterface.model_generate - else: # Default or ACCELERATOR - return AsyncTransformerInterface.model_generate + # Default or ACCELERATOR + return AsyncTransformerInterface.model_generate return self._module_call_fn @module_call_fn.setter @@ -407,7 +410,9 @@ def __init__( parallel_worker_config: ParallelWorkerConfig, ): parallel_worker_config.set_env_vars() - dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}")) + dist.init_process_group( + backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}") + ) torch.cuda.set_device(self.local_rank) dist.barrier() self.worker_config = parallel_worker_config @@ -559,7 +564,6 @@ def _init_local_cluster( port=parallel_mode_config.scheduler_port, memory_limit=None, # do not let Dask manage memory - if we OOM, we OOM device_memory_limit=0, # Disable gpu memory spilling. Should be handled by FSDP - **kwargs, ) self.cluster.scale(parallel_mode_config.num_workers) self._initialize_workers(config, parallel_mode_config) diff --git a/src/ldp/nn/handlers/transformer_handler_fsdp2.py b/src/ldp/nn/handlers/transformer_handler_fsdp2.py index e29ca7e0..df9474ee 100644 --- a/src/ldp/nn/handlers/transformer_handler_fsdp2.py +++ b/src/ldp/nn/handlers/transformer_handler_fsdp2.py @@ -216,7 +216,9 @@ def __init__( parallel_worker_config: ParallelWorkerConfig, ): parallel_worker_config.set_env_vars() - dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}")) + dist.init_process_group( + backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}") + ) dist.barrier() torch.cuda.set_device(self.local_rank) dist.barrier() From 770bc0688763994b7090751aa784d2753defd578 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 29 Apr 2025 18:12:42 -0500 Subject: [PATCH 32/34] update fsdp2 backend --- pyproject.toml | 4 +- src/ldp/nn/__init__.py | 2 + src/ldp/nn/agent/simple_local_agent.py | 4 +- src/ldp/nn/graph/llm_call_op.py | 3 + src/ldp/nn/handlers/transformer_handler.py | 134 ++-- .../nn/handlers/transformer_handler_fsdp2.py | 597 +----------------- uv.lock | 50 +- 7 files changed, 115 insertions(+), 679 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8f88e3c3..fd9b34df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ nn = [ "dask-jobqueue", "dask[distributed]", "tokenizers>0.20", - "torch>=2.5,<2.7", # Temporarily pin <2.6 until someone fixes our CI with torch 2.6 + "torch>=2.6,<2.7", # Use torch 2.6 "transformers>=4.46", "wandb", ] @@ -153,7 +153,7 @@ explicit_package_bases = true mypy_path = "$MYPY_CONFIG_FILE_DIR/src,$MYPY_CONFIG_FILE_DIR/packages/lmi/src" # Specifies the OS platform for the target program, for example darwin or win32 # (meaning OS X or Windows, respectively). The default is the current platform -# as revealed by Python’s sys.platform variable. +# as revealed by Python's sys.platform variable. platform = "linux" # Comma-separated list of mypy plugins. plugins = ["pydantic.mypy"] diff --git a/src/ldp/nn/__init__.py b/src/ldp/nn/__init__.py index 77b54103..01a59172 100644 --- a/src/ldp/nn/__init__.py +++ b/src/ldp/nn/__init__.py @@ -12,6 +12,7 @@ ParallelTransformerHandler, TransformerHandler, TransformerHandlerConfig, + TransformerImplementation, collate_fn_transformer_left_pad, collate_fn_transformer_right_pad, decollate_fn_transformer_decoder, @@ -35,6 +36,7 @@ "TorchDType", "TransformerHandler", "TransformerHandlerConfig", + "TransformerImplementation", "collate_fn_transformer_left_pad", "collate_fn_transformer_right_pad", "decollate_fn_transformer_decoder", diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index bf173b2a..2dd6ce12 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -15,6 +15,7 @@ from ldp.nn.handlers.chunking import TensorChunker from ldp.nn.handlers.transformer_handler import ( ParallelModeConfig, + TransformerImplementation, logits_to_logprobs, ) from ldp.nn.lm_config import LMConfig as _LMConfig @@ -31,7 +32,7 @@ class AgentLMConfig(_LMConfig): # distribution parallel_mode: ParallelModeConfig | None = None - + implementation: TransformerImplementation = TransformerImplementation.ACCELERATOR # sampling parameters temperature: float = 1.0 max_new_tokens: int = 50 @@ -80,6 +81,7 @@ def __init__( batch_size=self.llm_model.batch_size, max_wait_interval=self.llm_model.max_wait_interval, parallel_mode_config=self.llm_model.parallel_mode, + implementation=self.llm_model.implementation, ) async def init_state(self, tools: list[Tool]) -> SimpleAgentState: diff --git a/src/ldp/nn/graph/llm_call_op.py b/src/ldp/nn/graph/llm_call_op.py index 35e82ef5..6581b1d7 100644 --- a/src/ldp/nn/graph/llm_call_op.py +++ b/src/ldp/nn/graph/llm_call_op.py @@ -15,6 +15,7 @@ LMType, ParallelModeConfig, TransformerHandlerConfig, + TransformerImplementation, collate_fn_transformer_left_pad, decollate_fn_transformer_decoder, ) @@ -39,6 +40,7 @@ def __init__( batch_size: int = 1, max_wait_interval: float = 0.1, parallel_mode_config: ParallelModeConfig | None = None, + implementation: TransformerImplementation = TransformerImplementation.ACCELERATOR, ) -> None: super().__init__() @@ -50,6 +52,7 @@ def __init__( batch_size=batch_size, max_wait_interval=max_wait_interval, parallel_mode_config=parallel_mode_config, + implementation=implementation, # constant configuration lm_type=LMType.GENERATION, collate_fn=partial( diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 19041629..7dc14be7 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -12,7 +12,7 @@ from enum import StrEnum, auto from functools import cache, partial, wraps from pathlib import Path -from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, cast +from typing import Concatenate, ParamSpec, Self, TypeVar, assert_never, Any import accelerate import torch @@ -29,6 +29,7 @@ FullStateDictConfig, FullyShardedDataParallel, MixedPrecision, + MixedPrecisionPolicy, ShardingStrategy, StateDictType, ) @@ -181,7 +182,7 @@ class TransformerHandlerConfig(BaseModel): batch_size: int max_wait_interval: float = 0.1 - _module_call_fn: Callable | None = None + module_call_fn: Callable collate_fn: Callable decollate_fn: Callable @@ -193,37 +194,14 @@ class TransformerHandlerConfig(BaseModel): ), ) - @property - def module_call_fn(self) -> Callable: - """Get the module call function based on implementation.""" - if self._module_call_fn is None: - if self.implementation == TransformerImplementation.FSDP2: - from .transformer_handler_fsdp2 import ( - AsyncTransformerInterface as FSDP2AsyncTransformerInterface, - ) - - return FSDP2AsyncTransformerInterface.model_generate - # Default or ACCELERATOR - return AsyncTransformerInterface.model_generate - return self._module_call_fn - - @module_call_fn.setter - def module_call_fn(self, value: Callable | None) -> None: - self._module_call_fn = value - def make_async_module(self, **kwargs) -> AsyncTransformerInterface: if self.parallel_mode_config: if self.implementation == TransformerImplementation.ACCELERATOR: return ParallelAsyncTransformer(config=self, **kwargs) if self.implementation == TransformerImplementation.FSDP2: - from .transformer_handler_fsdp2 import ( - ParallelTransformerHandler as FSDP2ParallelTransformerHandler, - ) + from .transformer_handler_fsdp2 import FSDP2ParallelAsyncTransformer - return cast( - "ParallelAsyncTransformer", - FSDP2ParallelTransformerHandler(config=self, **kwargs), - ) + return FSDP2ParallelAsyncTransformer(config=self, **kwargs) raise ValueError(f"Unsupported implementation: {self.implementation}") return AsyncTransformer(config=self, **kwargs) @@ -243,19 +221,22 @@ async def __call__( # type: ignore[override] @staticmethod def model_generate(model: PreTrainedModel, *args, **kwargs): """A method that can be used as module_call_fn to sample from an LLM.""" - # Summoning params per https://github.com/pytorch/pytorch/issues/100069 - # If model is not FSDP, this context manager is a no-op. - with FullyShardedDataParallel.summon_full_params(model, recurse=False): - logger.debug( - f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" - f" {os.environ.get('RANK')}" - ) - return model.generate( - *args, - **kwargs, - pad_token_id=model.config.pad_token_id, # not always set properly by .generate() - eos_token_id=model.config.eos_token_id, + logger.info( + f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" + f" {os.environ.get('RANK')}" + ) + if model.training: + logger.warning( + f"Model is in training mode at rank {os.environ.get('RANK')}, setting to eval mode" ) + model.eval() + + return model.generate( + *args, + **kwargs, + pad_token_id=model.config.pad_token_id, # not always set properly by .generate() + eos_token_id=model.config.eos_token_id, + ) class TransformerHandler(ModuleHandler): @@ -284,12 +265,12 @@ def __init__(self, config: TransformerHandlerConfig): self.tokenizer, self.config.lm_config.chat_template ) - self._setup_accelerator() + self._setup_fsdp() if config.checkpoint is not None: self.load_checkpoint(config.checkpoint) - def _setup_accelerator(self): + def _setup_fsdp(self): self.accelerator = accelerate.Accelerator( # This has to be disabled because accelerator wraps forward() to upcast outputs to fp32. That # causes problems with generation, where the cache is expected to be in the same dtype as the model. @@ -344,6 +325,7 @@ def __init__(self, config: TransformerHandlerConfig): max_wait_interval=config.max_wait_interval, collate_fn=config.collate_fn, decollate_fn=config.decollate_fn, + module_call_fn=config.module_call_fn, ) async def __call__( @@ -410,22 +392,21 @@ def __init__( parallel_worker_config: ParallelWorkerConfig, ): parallel_worker_config.set_env_vars() - dist.init_process_group( - backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}") - ) + torch.cuda.set_device(self.local_rank) - dist.barrier() + dist.init_process_group(backend="nccl") + self.worker_config = parallel_worker_config super().__init__(config) - def _setup_accelerator(self): + def _setup_fsdp(self): bf16 = self.config.lm_config.dtype == TorchDType.bf16 mixed_precision = None if bf16: bf16_ready = ( torch.version.cuda - # and torch.cuda.is_bf16_supported() # TODO add it back + and torch.cuda.is_bf16_supported() and torch.version.cuda >= "11.0" and dist.is_nccl_available() and nccl.version() >= (2, 10) @@ -439,10 +420,17 @@ def _setup_accelerator(self): buffer_dtype=torch.bfloat16, ) + mixed_precision = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + output_dtype=torch.bfloat16, + ) + self.accelerator = accelerate.Accelerator( - # See note in TransformerHandler._setup_accelerator() about this + # See note in TransformerHandler._setup_fsdp() about this # mixed_precision=("bf16" if bf16 else "no"), fsdp_plugin=accelerate.FullyShardedDataParallelPlugin( + fsdp_version=2, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision_policy=mixed_precision, auto_wrap_policy="transformer_based_wrap", @@ -452,13 +440,16 @@ def _setup_accelerator(self): sync_module_states=self.worker_config.cpu_ram_efficient_loading, state_dict_type=self.worker_config.state_dict_type, backward_prefetch=self.worker_config.backward_prefetch, + reshard_after_forward=self.worker_config.reshard_after_forward, ), ) if self.config.lm_config.device == "meta": self.module = prepare_model_for_fsdp_with_meta_device(self.module) - self.module = self.accelerator.prepare(self.module) + # TODO: evaluation_mode=True gives perf boost. However we can't train, + # allow control over this param + self.module = self.accelerator.prepare_model(self.module) def set_seed(self, seed: int) -> None: """Set the seed for the current worker.""" @@ -518,6 +509,11 @@ def __init__(self, config: TransformerHandlerConfig): self.tokenizer, self.config.lm_config.chat_template ) + # This uses NVIDIA's NVML layer instead of native CUDA, which is more robust in GPU detection + # post initialization. This prevents issues with forked processes wrongly detecting the + # default GPU as cuda:0 + os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" + match parallel_mode_config.execution_mode: # TODO: see if we can just access `parallel_mode_config` as a # `config` attribute instead of passing both. @@ -553,10 +549,6 @@ def _init_local_cluster( # lazy import since dask-cuda only works on Linux machines from dask_cuda import LocalCUDACluster - # This uses NVIDIA's NVML layer instead of native CUDA, which is more robust in GPU detection - # post initialization. This prevents issues with forked processes wrongly detecting the - # default GPU as cuda:0 - os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" self.cluster = LocalCUDACluster( n_workers=parallel_mode_config.num_workers, threads_per_worker=parallel_mode_config.num_cpus_per_worker, @@ -638,17 +630,18 @@ def _initialize_workers( self.client = Client(self.cluster) self.client.wait_for_workers(parallel_mode_config.num_workers) - def get_cuda_visible_devices() -> int | None: - device = os.environ.get("CUDA_VISIBLE_DEVICES", None) - if device is not None: - # If has several devices, assume the first one is the one to use for that worker - if "," in device: - device = device.split(",", maxsplit=1)[0] - os.environ["CUDA_VISIBLE_DEVICES"] = device - return int(device) - return None - - worker_to_cuda_device = self.client.run(get_cuda_visible_devices) + # TODO: enable when gres is enabled in our cluster + # def get_cuda_visible_devices() -> int | None: + # device = os.environ.get("CUDA_VISIBLE_DEVICES", None) + # if device is not None: + # # If has several devices, assume the first one is the one to use for that worker + # if "," in device: + # device = device.split(",", maxsplit=1)[0] + # os.environ["CUDA_VISIBLE_DEVICES"] = device + # return int(device) + # return None + + # worker_to_cuda_device = self.client.run(get_cuda_visible_devices) workers_info = self.client.scheduler_info()["workers"] sorted_workers = dict( sorted(workers_info.items(), key=lambda item: item[1]["id"]) @@ -658,12 +651,12 @@ def get_cuda_visible_devices() -> int | None: futures = [] worker_ids = [] - for rank, (worker_address, worker_data) in enumerate(sorted_workers.items()): + for rank, (_, worker_data) in enumerate(sorted_workers.items()): worker_id = worker_data["id"] # On some occasions, dask SLURM integration auto assigns CUDA_VISIBLE_DEVICES, otherwise we set it here - worker_cuda_device = worker_to_cuda_device[worker_address] - if worker_cuda_device is None: - worker_cuda_device = rank % parallel_mode_config.num_gpus_per_node + # worker_cuda_device = worker_to_cuda_device[worker_address] + # if worker_cuda_device is None: + worker_cuda_device = rank % parallel_mode_config.num_gpus_per_node parallel_worker_config = ParallelWorkerConfig( rank=rank, @@ -674,7 +667,7 @@ def get_cuda_visible_devices() -> int | None: **parallel_mode_config.model_dump(), ) future_op = self.client.submit( - ParallelTransformerHandler, + self._get_parallel_transformer_handler_cls(), config=config, parallel_worker_config=parallel_worker_config, workers=[worker_id], @@ -686,6 +679,9 @@ def get_cuda_visible_devices() -> int | None: self.actors: list[Actor] = self._client_gather(futures) self.worker_ids = worker_ids + def _get_parallel_transformer_handler_cls(self): + return ParallelTransformerHandler + async def __call__( self, inputs: str | BatchEncoding | list[dict] | None = None, diff --git a/src/ldp/nn/handlers/transformer_handler_fsdp2.py b/src/ldp/nn/handlers/transformer_handler_fsdp2.py index df9474ee..98e65f8b 100644 --- a/src/ldp/nn/handlers/transformer_handler_fsdp2.py +++ b/src/ldp/nn/handlers/transformer_handler_fsdp2.py @@ -1,22 +1,12 @@ from __future__ import annotations -import asyncio -import atexit import logging import os import time -from abc import ABC, abstractmethod -from collections.abc import Awaitable, Callable -from functools import partial, wraps from pathlib import Path -from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, overload import torch import torch.distributed as dist -import tree -from dask import config -from dask.distributed import Actor, ActorFuture, Client -from distributed.utils import sync try: assert torch.__version__ >= "2.6.0", "FSDP2 requires PyTorch 2.6.0 or higher" @@ -29,107 +19,19 @@ ) except (ImportError, AssertionError) as e: raise ImportError(f"FSDP2 requires PyTorch 2.6.0 or higher: {e}") from e -from transformers import PreTrainedModel -from transformers.generation.utils import GenerateDecoderOnlyOutput -from transformers.tokenization_utils_base import BatchEncoding -from ldp.graph.async_torch import AsyncBufferedWorker, AsyncTorchModule -from ldp.nn.handlers.chunking import TensorChunker -from ldp.nn.handlers.module_handler import ModuleExecutionInterface, ModuleHandler from ldp.nn.lm_config import TorchDType -from ldp.nn.utils import set_seed from .transformer_handler import ( - AsyncTransformer, - ExecutionMode, FSDPConfig, - LMType, - ParallelModeConfig, - ParallelWorkerConfig, - TransformerHandlerConfig, - _get_data_device, - _get_tokenized_inputs, - _move_tensor, - _process_outputs, - maybe_set_tokenizer_chat_template, + ParallelAsyncTransformer, + ParallelTransformerHandler, ) logger = logging.getLogger(__name__) -config.set({ - # We have no use for rebooting workers in aviary for now, and rebooting workers - # is annoying when debugging. - "distributed.scheduler.allowed-failures": 0, - # FSDP forward/backward passes can take way longer than the default warning at 3s - "distributed.admin.tick.limit": "30s", - # Gives us more time to debug a downed worker. TODO: see if there are negative consequences - # of having this always enabled - "distributed.comm.timeouts.connect": "300s", - "distributed.comm.timeouts.tcp": "1200s", -}) - -TReturn = TypeVar("TReturn") -TParams = ParamSpec("TParams") - - -class AsyncTransformerInterface(ModuleExecutionInterface, AsyncTorchModule, ABC): - """Base class for async interactions with a transformer model.""" - - @abstractmethod - async def __call__( # type: ignore[override] - self, - inputs: str | BatchEncoding | list[dict], - tools_json: list[dict] | None = None, - **kwargs, - ) -> tuple[str, torch.Tensor]: - """Call the transformer on a single input, which may be encoded.""" - - @staticmethod - def model_generate(model: PreTrainedModel, *args, **kwargs): - """A method that can be used as module_call_fn to sample from an LLM.""" - logger.debug( - f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" - f" {os.environ.get('RANK')}" - ) - return model.generate( - *args, - **kwargs, - pad_token_id=model.config.pad_token_id, # not always set properly by .generate() - eos_token_id=model.config.eos_token_id, - ) - - -class TransformerHandler(ModuleHandler): - def __init__(self, config: TransformerHandlerConfig): - # Maybe this should be configurable? Hard to isolate the effect though - torch.set_float32_matmul_precision("high") - - self.config = config - # Use local_rank to resolve model location only in the main process *for each node* - config.lm_config.resolve_model_location(is_main_process=self.local_rank == 0) - - match config.lm_type: - case LMType.GENERATION: - tokenizer, model = config.lm_config.get_causal_lm() - # On left for https://github.com/huggingface/transformers/pull/7552 - # ^ that seems to work for most HF models w/ absolute position embeddings - # Left padding always works for relative position embeddings - tokenizer.padding_side = "left" - case LMType.REGRESSION: - tokenizer, model = config.lm_config.get_regression_lm() - case _: - assert_never(config.lm_type) - super().__init__(model) - self.tokenizer = tokenizer - maybe_set_tokenizer_chat_template( - self.tokenizer, self.config.lm_config.chat_template - ) - - self._setup_fsdp() - - if config.checkpoint is not None: - self.load_checkpoint(config.checkpoint) +class FSDP2ParallelTransformerHandler(ParallelTransformerHandler): def _setup_fsdp(self): """Set up FSDP2 module wrapping.""" if not dist.is_initialized(): @@ -140,7 +42,9 @@ def _setup_fsdp(self): # Setup mixed precision policy if needed mp_policy = None + logger.info(f"Setting up FSDP2 with dtype {self.config.lm_config.dtype}") if self.config.lm_config.dtype == TorchDType.bf16: + logger.info("Setting up FSDP2 with bfloat16 dtype") mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, @@ -166,10 +70,6 @@ def _setup_fsdp(self): # Register model.generate as an FSDP forward method to handle generation correctly register_fsdp_forward_method(self.module, "generate") - @property - def local_rank(self) -> int: - return int(os.getenv("LOCAL_RANK", "0")) - def load_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: logger.info(f'Loading checkpoint from "{ckpt}"') start_time = time.perf_counter() @@ -188,7 +88,6 @@ def load_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: self.barrier() logger.info(f"Loading checkpoint took {time.perf_counter() - start_time:.2f}s") - torch.cuda.empty_cache() def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: start_time = time.perf_counter() @@ -203,394 +102,10 @@ def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: self.barrier() logger.info(f"Saving checkpoint took {time.perf_counter() - start_time:.2f}s") - @staticmethod - def barrier() -> None: - if dist.is_initialized(): - dist.barrier() - - -class ParallelTransformerHandler(TransformerHandler): - def __init__( - self, - config: TransformerHandlerConfig, - parallel_worker_config: ParallelWorkerConfig, - ): - parallel_worker_config.set_env_vars() - dist.init_process_group( - backend="nccl", device_id=torch.device(f"cuda:{self.local_rank}") - ) - dist.barrier() - torch.cuda.set_device(self.local_rank) - dist.barrier() - self.worker_config = parallel_worker_config - super().__init__(config) - - def set_seed(self, seed: int) -> None: - """Set the seed for the current worker.""" - set_seed(seed) - - def _exec_func( - self, - func: Callable[Concatenate[Self, TParams], TReturn] | str, - *args, - **kwargs, - ) -> TReturn: - # data will be on CPU when sent from controller - data_device = _get_data_device() - to_device = partial(_move_tensor, device=data_device) - args = tree.map_structure(to_device, args) - kwargs = tree.map_structure(to_device, kwargs) - - try: - with torch.autocast( - device_type=self.module.device.type, - dtype=torch.bfloat16 - if self.config.lm_config.dtype == TorchDType.bf16 - else torch.float32, - ): - res = ( - getattr(self, func)(*args, **kwargs) - if isinstance(func, str) - else func(self, *args, **kwargs) - ) - - # Needed to prevent GPU memory leak to the main process scheduling the workers - if isinstance(res, GenerateDecoderOnlyOutput): - res.past_key_values = None - res["past_key_values"] = None - - to_cpu = partial(_move_tensor, device=torch.device("cpu")) - return tree.map_structure(to_cpu, res) - except Exception as e: - # Re-raise the exception with traceback preserved. For some exceptions, Dask - # modifies or loses the original traceback when crossing process boundaries. - # RuntimeError preserves the traceback when using with_traceback() of original - # exception. - raise RuntimeError(str(e)).with_traceback(e.__traceback__) # noqa: B904 - - def __del__(self) -> None: - if dist.is_initialized(): - dist.destroy_process_group() - - -class ParallelAsyncTransformer(AsyncTransformerInterface): - def __init__(self, config: TransformerHandlerConfig): - self._initialized = False - - parallel_mode_config = config.parallel_mode_config - if not parallel_mode_config: - raise ValueError("Parallel mode config must be provided.") - self.config = config - self.tokenizer = config.lm_config.get_tokenizer() - maybe_set_tokenizer_chat_template( - self.tokenizer, self.config.lm_config.chat_template - ) - - match parallel_mode_config.execution_mode: - # TODO: see if we can just access `parallel_mode_config` as a - # `config` attribute instead of passing both. - case ExecutionMode.LOCAL_MACHINE: - self._init_local_cluster(config, parallel_mode_config) - case ExecutionMode.SLURM_CLUSTER: - self._init_slurm_cluster(config, parallel_mode_config) - case _: - assert_never(parallel_mode_config.execution_mode) - - self._initialized = True - - atexit.register(self.teardown) - - # don't call AsyncTorchModule.__init__ because we don't need to set up module[_call_fn] - AsyncBufferedWorker.__init__( - self, - batch_size=config.batch_size, - max_wait_interval=config.max_wait_interval, - collate_fn=config.collate_fn, - decollate_fn=config.decollate_fn, - ) - - def handler_call_fn(handler: ParallelTransformerHandler, *args, **kwargs): - return config.module_call_fn(handler.module, *args, **kwargs) - - self.handler_call_fn = handler_call_fn - - def _init_local_cluster( - self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig - ): - """Initialize a Dask cluster on local machine.""" - # lazy import since dask-cuda only works on Linux machines - from dask_cuda import LocalCUDACluster - - kwargs = {} - if os.environ.get("USE_UCX"): - kwargs = { - "protocol": "ucx", - "enable_tcp_over_ucx": True, - "enable_infiniband": True, - "enable_nvlink": True, - } - - self.cluster = LocalCUDACluster( - n_workers=parallel_mode_config.num_workers, - threads_per_worker=parallel_mode_config.num_cpus_per_worker, - host=parallel_mode_config.scheduler_addr, - port=parallel_mode_config.scheduler_port, - memory_limit=None, # do not let Dask manage memory - if we OOM, we OOM - device_memory_limit=0, # Disable gpu memory spilling. Should be handled by FSDP - **kwargs, - ) - self.cluster.scale(parallel_mode_config.num_workers) - self._initialize_workers(config, parallel_mode_config) - - def _init_slurm_cluster( - self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig - ): - """Initialize a SLURM-based Dask cluster with GPU allocation. - - Note: Dask's integration with SLURM currently only supports allocating single entire node - at a time, with each node running as a single SLURM task. This implementation adapts - to that limitation by requesting complete nodes and running multiple workers (one per GPU) - within each node. If our cluster eventually supports GRES (Generic Resource) scheduling, - this implementation could be modified to allow for more granular GPU allocation across - nodes rather than requiring full node allocation (I think, needs to be tested). - """ - # Lazy import because dask_jobqueue cannot be started in a subprocess, which - # happens e.g. with streamlit - from dask_jobqueue.slurm import SLURMCluster - - # Validate that num_workers is divisible by num_gpus_per_node - num_gpus_per_node = parallel_mode_config.num_gpus_per_node - if parallel_mode_config.num_workers % num_gpus_per_node != 0: - raise ValueError( - f"Number of workers ({parallel_mode_config.num_workers}) must be divisible by " - f"num_gpus_per_node ({num_gpus_per_node}). We assume each node has {num_gpus_per_node} GPUs, " - f"and current dask-jobqueue infrastructure only supports allocating whole nodes. " - ) - # TODO: add support for gres when available in our cluster for partial node allocation - - # Calculate number of jobs needed (each job = 1 slurm node with num_gpus_per_node GPUs) - num_jobs = parallel_mode_config.num_workers // num_gpus_per_node - - log_dir = parallel_mode_config.log_directory - os.makedirs(log_dir, exist_ok=True) - - # Calculate total memory needed per node (memory_per_worker * num_gpus_per_node) - memory_per_worker = parallel_mode_config.memory_per_worker - MEMORY_UNIT_LENGTH = 2 # Memory units are typically 2 chars (e.g. "GB", "MB") - value = int( - memory_per_worker[:-MEMORY_UNIT_LENGTH] - ) # Get numeric value by removing last 2 chars (e.g. "GB") - unit = memory_per_worker[-MEMORY_UNIT_LENGTH:] # Get unit (e.g. "GB") - assert len(unit) == MEMORY_UNIT_LENGTH, ( - f"Memory unit must be {MEMORY_UNIT_LENGTH} characters long, got {unit}" - ) - total_memory = f"{value * parallel_mode_config.num_gpus_per_node}{unit}" - - self.cluster = SLURMCluster( - cores=parallel_mode_config.num_cpus_per_worker * num_gpus_per_node, - memory=total_memory, - processes=num_gpus_per_node, # Each job runs num_gpus_per_node dask workers (one per GPU) - walltime=parallel_mode_config.walltime, - job_extra_directives=[ - "--nodes=1", # Always request 1 node per job - "--exclusive", # Exclusive node access - "--mem=0", # Use all available memory - f"--cpus-per-task={parallel_mode_config.num_cpus_per_worker}", - f"-o {log_dir}/job_%j_task_%t.out", - f"-e {log_dir}/job_%j_task_%t.err", - ], - log_directory=log_dir, - ) - - # Scale jobs to the required number of jobs - self.cluster.scale(jobs=num_jobs) - self._initialize_workers(config, parallel_mode_config) - def _initialize_workers( - self, config: TransformerHandlerConfig, parallel_mode_config: ParallelModeConfig - ): - self.client = Client(self.cluster) - self.client.wait_for_workers(parallel_mode_config.num_workers) - - def get_cuda_visible_devices() -> int | None: - device = os.environ.get("CUDA_VISIBLE_DEVICES", None) - if device is not None: - # If has several devices, assume the first one is the one to use for that worker - if "," in device: - device = device.split(",", maxsplit=1)[0] - os.environ["CUDA_VISIBLE_DEVICES"] = device - return int(device) - return None - - worker_to_cuda_device = self.client.run(get_cuda_visible_devices) - workers_info = self.client.scheduler_info()["workers"] - sorted_workers = dict( - sorted(workers_info.items(), key=lambda item: item[1]["id"]) - ) - # The first worker is the master in the torch distributed setup - master_addr = next(iter(sorted_workers.values()))["host"] - - futures = [] - worker_ids = [] - for rank, (worker_address, worker_data) in enumerate(sorted_workers.items()): - worker_id = worker_data["id"] - # On some occasions, dask SLURM integration auto assigns CUDA_VISIBLE_DEVICES, otherwise we set it here - worker_cuda_device = worker_to_cuda_device[worker_address] - if worker_cuda_device is None: - worker_cuda_device = rank % parallel_mode_config.num_gpus_per_node - - parallel_worker_config = ParallelWorkerConfig( - rank=rank, - world_size=parallel_mode_config.num_workers, - local_rank=worker_cuda_device, - master_addr=master_addr, - master_port=parallel_mode_config.torch_port, - **parallel_mode_config.model_dump(), - ) - future_op = self.client.submit( - ParallelTransformerHandler, - config=config, - parallel_worker_config=parallel_worker_config, - workers=[worker_id], - actor=True, - ) - futures.append(future_op) - worker_ids.append(worker_id) - - self.actors: list[Actor] = self._client_gather(futures) - self.worker_ids = worker_ids - - async def __call__( - self, - inputs: str | BatchEncoding | list[dict] | None = None, - tools_json: list[dict] | None = None, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - **kwargs, - ) -> tuple[str, torch.Tensor]: - if inputs is None: - inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - inputs_tokenized = _get_tokenized_inputs(self.tokenizer, inputs, tools_json) - inputs_len = inputs_tokenized["input_ids"].shape[1] - outputs = await AsyncBufferedWorker.__call__(self, **inputs_tokenized, **kwargs) - AsyncTransformer._maybe_finalize_logits_processors( - kwargs.get("logits_processor"), outputs - ) - - return _process_outputs(self.config, self.tokenizer, outputs, inputs_len) - - async def _batched_call(self, batch_kwargs: dict[str, Any]): - return self._submit_and_gather( - self.handler_call_fn, **batch_kwargs, split_data=True - ) - - def _submit_and_gather( - self, - func: Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn] | str, - *args, - split_data: bool = False, - **kwargs, - ) -> list[TReturn]: - """Submit a function to all workers and gather the results. - - Args: - func: The function to send to each worker. If a string is provided, - then getattr(handler, func) is used. If func is not a string, - the first argument must be the ParallelTransformerHandler that it will - be executed on. - split_data: If True, split the data between workers. If False, - send the same data to all workers. - *args: Positional arguments to pass to the method. - **kwargs: Keyword arguments to pass to the method. - - Returns: - The gathered results from the workers. - """ - if split_data: - chunker = TensorChunker( - num_chunks=len(self.actors), - ) - split_args, split_kwargs, dummy_flags = chunker.chunkify(*args, **kwargs) - else: - split_args = [args] * len(self.actors) - split_kwargs = [kwargs] * len(self.actors) - - futures = [ - handler._exec_func( - func, - *args_i, - **kwargs_i, - ) - for handler, worker_id, args_i, kwargs_i in zip( - self.actors, self.worker_ids, split_args, split_kwargs, strict=True - ) - ] - results: list[TReturn] = self._client_gather(futures) - - if split_data: - return chunker.dechunkify(results, dummy_flags) - return results - - def wrap_afunc( - self, - func: Callable[ - Concatenate[ParallelTransformerHandler, TParams], Awaitable[TReturn] - ], - **kwargs, - ) -> Callable[TParams, Awaitable[TReturn]]: - raise NotImplementedError( - "ParallelAsyncTransformer does not implement wrap_afunc(). " - "Wrap a synchronous function with wrap_func() instead." - ) - - @overload - def wrap_func( - self, - *, - worker_agg_fn: Callable[[list[TReturn]], TReturn] | None = None, - **kwargs, - ) -> Callable[ - [Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn]], - Callable[TParams, TReturn], - ]: ... - - @overload - def wrap_func( - self, - func: Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn], - *, - worker_agg_fn: Callable[[list[TReturn]], TReturn] | None = None, - **kwargs, - ) -> Callable[TParams, TReturn]: ... - - def wrap_func( - self, - func: ( - Callable[Concatenate[ParallelTransformerHandler, TParams], TReturn] | None - ) = None, - *, - worker_agg_fn: Callable[[list[TReturn]], TReturn] | None = None, - **kwargs, - ) -> Callable: - """Wrap a function to execute on all workers and return gathered results. - - Args: - func: The function to wrap. - worker_agg_fn: A function to aggregate the results from all workers. - kwargs: Arguments that are discarded. Included here to enable a - subclass to add additional arguments. - """ - if worker_agg_fn is None: - raise ValueError("worker_agg_fn must be provided.") - - if func is None: - return partial(self.wrap_func, worker_agg_fn=worker_agg_fn, **kwargs) - - @wraps(func) - def wrapped_func(*args, **kwargs) -> TReturn: - return worker_agg_fn(self._submit_and_gather(func, *args, **kwargs)) - - return wrapped_func +class FSDP2ParallelAsyncTransformer(ParallelAsyncTransformer): + def _get_parallel_transformer_handler_cls(self): + return FSDP2ParallelTransformerHandler def state_dict(self, **kwargs) -> dict[str, torch.Tensor]: """Get consolidated state dict from all workers. @@ -611,99 +126,3 @@ def state_dict_worker( # Only need the state dict from rank 0 state_dict = self._submit_and_gather(state_dict_worker, **kwargs)[0] return {k: v.cpu() for k, v in state_dict.items()} - - def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: - # For some reason, Dask hangs when we pass a large object (e.g. state_dict) - # directly to the workers. I can replicate it with the following: - # - # @handler.wrap_func - # def hello(handler, _): - # print("hello") - # - # hello([0] * 1_000_000) - # NOTE: this does not seem to be FSDP-related, as the issue didn't go away when - # I disabled FSDP. - raise NotImplementedError( - "ParallelAsyncTransformer.load_state_dict() is not implemented yet. It is" - " recommended to use .save_checkpoint() and .load_checkpoint() instead. " - ) - - def load_checkpoint(self, ckpt: os.PathLike | str) -> None: - self._submit_and_gather("load_checkpoint", ckpt) - - def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None: - self._submit_and_gather("save_checkpoint", ckpt, **kwargs) - - def teardown(self) -> None: - if self._initialized: - self.client.shutdown() - self.cluster.close() - del self.client - del self.cluster - self._initialized = False - - def __del__(self) -> None: - self.teardown() - - @staticmethod - def _wrap_dask_future(dask_future: ActorFuture): - """Converts a Dask ActorFuture into an awaitable asyncio.Future.""" - loop = asyncio.get_running_loop() - return asyncio.ensure_future(loop.run_in_executor(None, dask_future.result)) - - @staticmethod - def _raise_exceptions(done, pending, wrapped_futures): - exceptions = [] - for future in done: - exc = future.exception() - if exc: - exceptions.append(exc) - if exceptions: - if len(exceptions) == 1: - raise exceptions[0] - raise ExceptionGroup("Multiple actor exceptions", exceptions) - - if pending: - pending_indices = sorted([wrapped_futures.index(p) for p in pending]) - raise TimeoutError( - f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} " - f"still pending. Pending task indices: {pending_indices}" - ) - - async def _client_gather_async(self, futures): - """Gather results from futures, propagating exceptions as they arrive. - - Unlike client.gather() which waits for all futures to complete before raising - any exceptions, this method processes futures as they complete and raises - exceptions immediately. This is crucial when using FSDP where workers may - be stuck waiting for each other when one worker crashes, causing long hangs. - - Note: Dask Actors currently have an issue where they're not working properly with - dask.gather() and can cause blocking issues or hide worker errors. This implementation - works around those limitations. - """ - try: - wrapped_futures = [self._wrap_dask_future(f) for f in futures] - - # Use asyncio.wait with FIRST_EXCEPTION instead of gather - done, pending = await asyncio.wait( - wrapped_futures, timeout=1200, return_when=asyncio.FIRST_EXCEPTION - ) - - self._raise_exceptions(done, pending, wrapped_futures) - - return await asyncio.gather(*wrapped_futures) - except Exception: - logger.exception("Error in dask workers: %s") - for future in wrapped_futures: - future.cancel() - self.teardown() - # sys.exit(1) would wait for dask to finish, which can cause hanging - # when workers are in a deadlock. Use os._exit to force immediate termination - # TODO: this is more of a hack, we should propagate special exception that is - # not caught by the rollout manager. - os._exit(1) - - def _client_gather(self, futures: list[ActorFuture]) -> list[Any]: - # Use distributed.utils.sync to run the async function in the current thread - return sync(self.client.loop, self._client_gather_async, futures) # type: ignore[arg-type] diff --git a/uv.lock b/uv.lock index 46d1ee32..4758631d 100644 --- a/uv.lock +++ b/uv.lock @@ -724,15 +724,19 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/a6/83/ce29720ccf934c6cfa9b9c95ebbe96558386e66886626066632b5e44afed/dm_tree-0.1.9.tar.gz", hash = "sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b", size = 35623 } wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/b6/2d2de9f8901ccc5b6f34aea678e732816853015b9d756c86efcec189bf4b/dm_tree-0.1.9-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7d7d784afaeb4b67d87d858261aaf02503939ddc1f09c4cca70728f9892ab004", size = 173561 }, { url = "https://files.pythonhosted.org/packages/3e/07/57459f32cf5683c25b596ab58f42a3305f91876c2f03d2fa6e9d0df75fcb/dm_tree-0.1.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c", size = 146926 }, { url = "https://files.pythonhosted.org/packages/e8/46/939fbf81177c7cb3b1e5ddebd696237b3be9520769cce882f064de497103/dm_tree-0.1.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7", size = 152851 }, { url = "https://files.pythonhosted.org/packages/35/3e/a46933e0157b0ac87619a754ce1a796b2afc6386fca7c11f95c010f40745/dm_tree-0.1.9-cp311-cp311-win_amd64.whl", hash = "sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf", size = 101522 }, + { url = "https://files.pythonhosted.org/packages/ee/02/61aa90ab695918b4389d75c99bf0ec3cd0abacf1cadbef4053626f23ce34/dm_tree-0.1.9-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a8d20eeab7fde77a3ed71f07716021eb0edfb4812a128eb381d108af3a310257", size = 175012 }, { url = "https://files.pythonhosted.org/packages/81/10/120cd40556407879c1069941bd8b0d1a75754128c1a5bf0e27dbcf2a49fc/dm_tree-0.1.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80c43417814b1181d3367b335460bfdd30b79ee187a64220e11f6ddd093a4b15", size = 147204 }, { url = "https://files.pythonhosted.org/packages/86/52/27607a275c12858b979b8e943d2bd3bd0f9028503bb7079d5830a8b3cac0/dm_tree-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2334cfe9d2ed4293f9f1c7aefba0657deaab9ea74b5fadd966f6d01d9b6b42d9", size = 153013 }, { url = "https://files.pythonhosted.org/packages/ea/97/4f78412f73a9350bc8f934441bae5b68b102c8f4240a7f06b4114b51d6de/dm_tree-0.1.9-cp312-cp312-win_amd64.whl", hash = "sha256:9020a5ce256fcc83aa4bc190cc96dd66e87685db0a6e501b0c06aa492c2e38fc", size = 102022 }, + { url = "https://files.pythonhosted.org/packages/5f/13/823788cd0f7964cadcfa56d1e0f9e5e987ee73b5db6273bc00168f524f1a/dm_tree-0.1.9-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:cfa33c2e028155810ad1b4e11928707bf47489516763a86e79cab2954d23bf68", size = 175000 }, { url = "https://files.pythonhosted.org/packages/37/6a/512abdf7f20acc6cd6fce77f7663014d129aa313b5953aa2603d58fdb0c9/dm_tree-0.1.9-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d05622d074353cf434049206e53c12147903a048c4bd7d77f2800d427413ad78", size = 147210 }, { url = "https://files.pythonhosted.org/packages/e5/0a/f4d72ffb64ab3edc1fa66261f81ee3b4142ab14cd8aa1dfc7bbeca5ee4ba/dm_tree-0.1.9-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68b0efad76703dd4648586c75618a48cdd671b68c3266fe980e323c15423607", size = 153043 }, { url = "https://files.pythonhosted.org/packages/0d/ee/529ce999770b4d621a64af86c60cfee52f0cdd7294752105179ebf1c07c6/dm_tree-0.1.9-cp313-cp313-win_amd64.whl", hash = "sha256:e97c34fcb44941c36b7ee81dcdbceba0fbe728bddcc77e5837ab2eb665bcbff8", size = 102043 }, + { url = "https://files.pythonhosted.org/packages/ee/3c/5b40f8862390e9172e776cf610f3791c1af01f140a5698799fbe4a97206f/dm_tree-0.1.9-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b06e7a5da1c31a82521a60060573527e8d24b9920fdd20b2ec86f08412737598", size = 180821 }, { url = "https://files.pythonhosted.org/packages/84/1d/3cdbeeb3f6937a47a26cee502bffeccc2e55b97dfcce8a1d1135ea1b5b47/dm_tree-0.1.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6893fcdc5cf1a4f459cfc383526d35d42e7c671ae565d7e429a2f2cb2cb93e89", size = 147282 }, { url = "https://files.pythonhosted.org/packages/c5/37/15603079854394f16e3833a7b50696c1f3cbf30a2243a119f64f18a16f36/dm_tree-0.1.9-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c", size = 153052 }, ] @@ -1547,7 +1551,7 @@ requires-dist = [ { name = "tenacity" }, { name = "tiktoken" }, { name = "tokenizers", marker = "extra == 'nn'", specifier = ">0.20" }, - { name = "torch", marker = "extra == 'nn'", specifier = ">=2.5,<2.7" }, + { name = "torch", marker = "extra == 'nn'", specifier = ">=2.6,<2.7" }, { name = "tqdm" }, { name = "tqdm", marker = "extra == 'rich'", specifier = ">=4.56" }, { name = "transformers", marker = "extra == 'nn'", specifier = ">=4.46" }, @@ -2216,6 +2220,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, ] +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 }, +] + [[package]] name = "nvidia-nccl-cu12" version = "2.21.5" @@ -3647,7 +3659,7 @@ wheels = [ [[package]] name = "torch" -version = "2.5.1" +version = "2.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -3663,24 +3675,28 @@ dependencies = [ { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457", size = 906474467 }, - { url = "https://files.pythonhosted.org/packages/40/04/bd91593a4ca178ece93ca55f27e2783aa524aaccbfda66831d59a054c31e/torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9", size = 91919450 }, - { url = "https://files.pythonhosted.org/packages/0d/4a/e51420d46cfc90562e85af2fee912237c662ab31140ab179e49bd69401d6/torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a", size = 203098237 }, - { url = "https://files.pythonhosted.org/packages/d0/db/5d9cbfbc7968d79c5c09a0bc0bc3735da079f2fd07cc10498a62b320a480/torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c", size = 63884466 }, - { url = "https://files.pythonhosted.org/packages/8b/5c/36c114d120bfe10f9323ed35061bc5878cc74f3f594003854b0ea298942f/torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03", size = 906389343 }, - { url = "https://files.pythonhosted.org/packages/6d/69/d8ada8b6e0a4257556d5b4ddeb4345ea8eeaaef3c98b60d1cca197c7ad8e/torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697", size = 91811673 }, - { url = "https://files.pythonhosted.org/packages/5f/ba/607d013b55b9fd805db2a5c2662ec7551f1910b4eef39653eeaba182c5b2/torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c", size = 203046841 }, - { url = "https://files.pythonhosted.org/packages/57/6c/bf52ff061da33deb9f94f4121fde7ff3058812cb7d2036c97bc167793bd1/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1", size = 63858109 }, - { url = "https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 }, + { url = "https://files.pythonhosted.org/packages/78/a9/97cbbc97002fff0de394a2da2cdfa859481fdca36996d7bd845d50aa9d8d/torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1", size = 766715424 }, + { url = "https://files.pythonhosted.org/packages/6d/fa/134ce8f8a7ea07f09588c9cc2cea0d69249efab977707cf67669431dcf5c/torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d", size = 95759416 }, + { url = "https://files.pythonhosted.org/packages/11/c5/2370d96b31eb1841c3a0883a492c15278a6718ccad61bb6a649c80d1d9eb/torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7", size = 204164970 }, + { url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713 }, + { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563 }, + { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867 }, + { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469 }, + { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538 }, + { url = "https://files.pythonhosted.org/packages/24/85/ead1349fc30fe5a32cadd947c91bda4a62fbfd7f8c34ee61f6398d38fb48/torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf", size = 766626191 }, + { url = "https://files.pythonhosted.org/packages/dd/b0/26f06f9428b250d856f6d512413e9e800b78625f63801cbba13957432036/torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b", size = 95611439 }, + { url = "https://files.pythonhosted.org/packages/c2/9c/fc5224e9770c83faed3a087112d73147cd7c7bfb7557dcf9ad87e1dda163/torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc", size = 204126475 }, + { url = "https://files.pythonhosted.org/packages/88/8b/d60c0491ab63634763be1537ad488694d316ddc4a20eaadd639cedc53971/torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2", size = 66536783 }, ] [[package]] @@ -3745,14 +3761,12 @@ wheels = [ [[package]] name = "triton" -version = "3.1.0" +version = "3.2.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock", marker = "python_full_version < '3.13' and sys_platform == 'linux'" }, -] wheels = [ - { url = "https://files.pythonhosted.org/packages/86/17/d9a5cf4fcf46291856d1e90762e36cbabd2a56c7265da0d1d9508c8e3943/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c", size = 209506424 }, - { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 }, + { url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636 }, + { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365 }, + { url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 }, ] [[package]] From 76ee37f288820505a1343f57f66dc254802c9b4e Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Tue, 29 Apr 2025 18:31:12 -0500 Subject: [PATCH 33/34] nits --- src/ldp/nn/__init__.py | 4 ++-- src/ldp/nn/agent/simple_local_agent.py | 4 ++-- src/ldp/nn/graph/llm_call_op.py | 8 +++++--- src/ldp/nn/handlers/README_FSDP2.md | 23 ---------------------- src/ldp/nn/handlers/transformer_handler.py | 14 ++++++------- tests/test_nn_models.py | 4 ++++ 6 files changed, 20 insertions(+), 37 deletions(-) delete mode 100644 src/ldp/nn/handlers/README_FSDP2.md diff --git a/src/ldp/nn/__init__.py b/src/ldp/nn/__init__.py index 01a59172..83b7b3ae 100644 --- a/src/ldp/nn/__init__.py +++ b/src/ldp/nn/__init__.py @@ -12,7 +12,7 @@ ParallelTransformerHandler, TransformerHandler, TransformerHandlerConfig, - TransformerImplementation, + ParallelizationStrategy, collate_fn_transformer_left_pad, collate_fn_transformer_right_pad, decollate_fn_transformer_decoder, @@ -36,7 +36,7 @@ "TorchDType", "TransformerHandler", "TransformerHandlerConfig", - "TransformerImplementation", + "ParallelizationStrategy", "collate_fn_transformer_left_pad", "collate_fn_transformer_right_pad", "decollate_fn_transformer_decoder", diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index 2dd6ce12..af22251d 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -15,7 +15,7 @@ from ldp.nn.handlers.chunking import TensorChunker from ldp.nn.handlers.transformer_handler import ( ParallelModeConfig, - TransformerImplementation, + ParallelizationStrategy, logits_to_logprobs, ) from ldp.nn.lm_config import LMConfig as _LMConfig @@ -32,7 +32,7 @@ class AgentLMConfig(_LMConfig): # distribution parallel_mode: ParallelModeConfig | None = None - implementation: TransformerImplementation = TransformerImplementation.ACCELERATOR + implementation: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR # sampling parameters temperature: float = 1.0 max_new_tokens: int = 50 diff --git a/src/ldp/nn/graph/llm_call_op.py b/src/ldp/nn/graph/llm_call_op.py index 6581b1d7..f9b1907e 100644 --- a/src/ldp/nn/graph/llm_call_op.py +++ b/src/ldp/nn/graph/llm_call_op.py @@ -12,10 +12,11 @@ from ldp.graph.op_utils import CallID, get_call_id, get_training_mode from ldp.graph.ops import GradInType, Op, OpCtx, ResultOrValue from ldp.nn.handlers.transformer_handler import ( + AsyncTransformerInterface, LMType, ParallelModeConfig, TransformerHandlerConfig, - TransformerImplementation, + ParallelizationStrategy, collate_fn_transformer_left_pad, decollate_fn_transformer_decoder, ) @@ -40,7 +41,7 @@ def __init__( batch_size: int = 1, max_wait_interval: float = 0.1, parallel_mode_config: ParallelModeConfig | None = None, - implementation: TransformerImplementation = TransformerImplementation.ACCELERATOR, + implementation: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR, ) -> None: super().__init__() @@ -52,9 +53,10 @@ def __init__( batch_size=batch_size, max_wait_interval=max_wait_interval, parallel_mode_config=parallel_mode_config, - implementation=implementation, + parallel_strategy=implementation, # constant configuration lm_type=LMType.GENERATION, + module_call_fn=AsyncTransformerInterface.model_generate, collate_fn=partial( collate_fn_transformer_left_pad, pad_token_id=pad_token_id ), diff --git a/src/ldp/nn/handlers/README_FSDP2.md b/src/ldp/nn/handlers/README_FSDP2.md deleted file mode 100644 index 9bfea2ea..00000000 --- a/src/ldp/nn/handlers/README_FSDP2.md +++ /dev/null @@ -1,23 +0,0 @@ -# FSDP2 Implementation for Transformer Handler - -This implementation replaces the Accelerate-based FSDP implementation with PyTorch's native FSDP2 API. - -## Key Changes - -1. **Direct use of FSDP2 APIs**: - - - Uses `fully_shard()` from `torch.distributed.fsdp.fully_shard` instead of Accelerate's wrapper - - Registers model methods with `register_fsdp_forward_method` to ensure proper handling of model.generate() - -2. **Simplified Configuration**: - - - Uses native FSDP2 policies such as `MixedPrecisionPolicy` and `OffloadPolicy` - - Removed dependency on Accelerate-specific config formats - -3. **State Dict Management**: - - - With FSDP2, state dicts contain DTensors, which can be converted to full tensors when needed - - Added utility to consolidate DTensor state dicts for checkpointing - -4. **Code Reuse**: - - Imports utility functions when possible from the original `transformer_handler.py` file diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py index 7dc14be7..6d6e283f 100644 --- a/src/ldp/nn/handlers/transformer_handler.py +++ b/src/ldp/nn/handlers/transformer_handler.py @@ -164,9 +164,9 @@ class LMType(StrEnum): REGRESSION = auto() -class TransformerImplementation(StrEnum): +class ParallelizationStrategy(StrEnum): ACCELERATOR = auto() # Current implementation using Accelerator - FSDP2 = auto() # New implementation using FSDP2 + FSDP2 = auto() # New implementation using vanilla FSDP2 class TransformerHandlerConfig(BaseModel): @@ -175,8 +175,8 @@ class TransformerHandlerConfig(BaseModel): lm_config: LMConfig lm_type: LMType checkpoint: str | None = None - implementation: TransformerImplementation = Field( - default=TransformerImplementation.ACCELERATOR, + parallel_strategy: ParallelizationStrategy = Field( + default=ParallelizationStrategy.ACCELERATOR, description="Which transformer implementation to use (Accelerator or FSDP2)", ) @@ -196,13 +196,13 @@ class TransformerHandlerConfig(BaseModel): def make_async_module(self, **kwargs) -> AsyncTransformerInterface: if self.parallel_mode_config: - if self.implementation == TransformerImplementation.ACCELERATOR: + if self.parallel_strategy == ParallelizationStrategy.ACCELERATOR: return ParallelAsyncTransformer(config=self, **kwargs) - if self.implementation == TransformerImplementation.FSDP2: + if self.parallel_strategy == ParallelizationStrategy.FSDP2: from .transformer_handler_fsdp2 import FSDP2ParallelAsyncTransformer return FSDP2ParallelAsyncTransformer(config=self, **kwargs) - raise ValueError(f"Unsupported implementation: {self.implementation}") + raise ValueError(f"Unsupported implementation: {self.parallel_strategy}") return AsyncTransformer(config=self, **kwargs) diff --git a/tests/test_nn_models.py b/tests/test_nn_models.py index 8e129e93..55917ce3 100644 --- a/tests/test_nn_models.py +++ b/tests/test_nn_models.py @@ -155,6 +155,7 @@ async def test_generation( lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=4, + module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=tokenizer.pad_token_id, @@ -200,6 +201,7 @@ def test_state_dicts( lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=1, + module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=model_config.get_tokenizer().pad_token_id, @@ -240,6 +242,7 @@ def test_distributed_checkpoints(self, sharded: bool) -> None: lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=1, + module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=model_config.get_tokenizer().pad_token_id, @@ -329,6 +332,7 @@ def test_consistent_weights(self): lm_config=model_config, lm_type=ldp.nn.LMType.GENERATION, batch_size=1, + module_call_fn=ldp.nn.AsyncTransformerInterface.model_generate, collate_fn=partial( ldp.nn.collate_fn_transformer_left_pad, pad_token_id=model_config.get_tokenizer().pad_token_id, From 877ad6bf261aa6ec5cfefe11ee2d8d9c598978d8 Mon Sep 17 00:00:00 2001 From: Ori Kabeli Date: Wed, 30 Apr 2025 08:15:24 -0500 Subject: [PATCH 34/34] nits --- src/ldp/nn/agent/simple_local_agent.py | 4 ++-- src/ldp/nn/graph/llm_call_op.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ldp/nn/agent/simple_local_agent.py b/src/ldp/nn/agent/simple_local_agent.py index af22251d..1c67b265 100644 --- a/src/ldp/nn/agent/simple_local_agent.py +++ b/src/ldp/nn/agent/simple_local_agent.py @@ -32,7 +32,7 @@ class AgentLMConfig(_LMConfig): # distribution parallel_mode: ParallelModeConfig | None = None - implementation: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR + parallel_strategy: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR # sampling parameters temperature: float = 1.0 max_new_tokens: int = 50 @@ -81,7 +81,7 @@ def __init__( batch_size=self.llm_model.batch_size, max_wait_interval=self.llm_model.max_wait_interval, parallel_mode_config=self.llm_model.parallel_mode, - implementation=self.llm_model.implementation, + parallel_strategy=self.llm_model.parallel_strategy, ) async def init_state(self, tools: list[Tool]) -> SimpleAgentState: diff --git a/src/ldp/nn/graph/llm_call_op.py b/src/ldp/nn/graph/llm_call_op.py index f9b1907e..cd98275e 100644 --- a/src/ldp/nn/graph/llm_call_op.py +++ b/src/ldp/nn/graph/llm_call_op.py @@ -41,7 +41,7 @@ def __init__( batch_size: int = 1, max_wait_interval: float = 0.1, parallel_mode_config: ParallelModeConfig | None = None, - implementation: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR, + parallel_strategy: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR, ) -> None: super().__init__() @@ -53,7 +53,7 @@ def __init__( batch_size=batch_size, max_wait_interval=max_wait_interval, parallel_mode_config=parallel_mode_config, - parallel_strategy=implementation, + parallel_strategy=parallel_strategy, # constant configuration lm_type=LMType.GENERATION, module_call_fn=AsyncTransformerInterface.model_generate,