Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dev/sft/sft-from-file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
import random

import art
from art.local import LocalBackend
from art.megatron import MegatronBackend
from art.utils.sft import train_sft_from_file


async def main():
backend = LocalBackend()
backend = MegatronBackend()

model_name = "run-" + "".join(
random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)
)
model = art.TrainableModel(
name=model_name,
project="sft-from-file",
base_model="meta-llama/Llama-3.1-8B-Instruct",
base_model="Qwen/Qwen3.6-35B-A3B",
)
await model.register(backend)

Expand Down
4 changes: 2 additions & 2 deletions dev/sft/sft-warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dotenv import load_dotenv

import art
from art.local import LocalBackend
from art.megatron import MegatronBackend
from art.utils.sft import create_sft_dataset_iterator

# Simple SFT trajectories - teach model to respond "maybe"
Expand Down Expand Up @@ -43,7 +43,7 @@ async def rl_rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
async def main():
load_dotenv()

backend = LocalBackend()
backend = MegatronBackend()
model_name = "sft-warmup-" + "".join(
random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)
)
Expand Down
4 changes: 4 additions & 0 deletions src/art/dev/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing_extensions import TypedDict

from art.types import SFTMetricLoggingConfig

if TYPE_CHECKING:
from art.megatron.routing_replay import MoeRoutingReplayBundle

Expand Down Expand Up @@ -40,3 +42,5 @@ class TrainConfig(TypedDict, total=False):

class TrainSFTConfig(TypedDict, total=False):
"""Experimental SFT configuration options. Use at your own risk."""

metric_logging: SFTMetricLoggingConfig
3 changes: 3 additions & 0 deletions src/art/metrics_taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from .trajectories import TrajectoryGroup

TRAIN_GRADIENT_STEPS_KEY = "data/step_num_gradient_steps"
SFT_METRIC_PREFIX = "sft"
SFT_GRADIENT_STEP_KEY = "gradient_step"
SFT_WANDB_GRADIENT_STEP_KEY = f"{SFT_METRIC_PREFIX}/{SFT_GRADIENT_STEP_KEY}"
_INVARIANT_METRIC_KEYS = frozenset({TRAIN_GRADIENT_STEPS_KEY})


Expand Down
74 changes: 68 additions & 6 deletions src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from .costs import CostCalculator
from .metrics import MetricsBuilder, is_builder_managed_metric
from .metrics_taxonomy import (
SFT_GRADIENT_STEP_KEY,
SFT_METRIC_PREFIX,
SFT_WANDB_GRADIENT_STEP_KEY,
TRAIN_GRADIENT_STEPS_KEY,
average_metric_samples,
build_data_metrics_from_summary,
summarize_trajectory_groups,
)
from .preprocessing.moe_routing import attach_moe_routing_metadata_to_choice
from .trajectories import Trajectory, TrajectoryGroup
from .types import TrainSFTConfig
from .types import SFTMetricLoggingConfig, TrainSFTConfig
from .utils.trajectory_logging import write_trajectory_groups_parquet

if TYPE_CHECKING:
Expand Down Expand Up @@ -625,6 +628,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
self,
"_wandb_defined_metrics",
{
SFT_WANDB_GRADIENT_STEP_KEY,
"training_step",
"time/wall_clock_sec",
},
Expand All @@ -634,12 +638,16 @@ def _get_wandb_run(self) -> Optional["Run"]:
# This allows out-of-order logging (e.g., async validation for previous steps).
run.define_metric("training_step")
run.define_metric("time/wall_clock_sec")
run.define_metric(SFT_WANDB_GRADIENT_STEP_KEY)
run.define_metric("reward/*", step_metric="training_step")
run.define_metric("loss/*", step_metric="training_step")
run.define_metric("throughput/*", step_metric="training_step")
run.define_metric("costs/*", step_metric="training_step")
run.define_metric("time/*", step_metric="training_step")
run.define_metric("data/*", step_metric="training_step")
run.define_metric(
f"{SFT_METRIC_PREFIX}/*", step_metric=SFT_WANDB_GRADIENT_STEP_KEY
)
run.define_metric("train/*", step_metric="training_step")
run.define_metric("val/*", step_metric="training_step")
run.define_metric("test/*", step_metric="training_step")
Expand Down Expand Up @@ -1230,6 +1238,7 @@ async def train_sft(
config: TrainSFTConfig | None = None,
_config: dev.TrainSFTConfig | None = None,
verbose: bool = False,
log_metrics: bool = True,
) -> None:
"""
Supervised fine-tune the model with an iterable of trajectories.
Expand All @@ -1241,31 +1250,84 @@ async def train_sft(
_config: Additional experimental configuration that is subject to change and
not yet part of the public API. Use at your own risk.
verbose: Whether to print verbose output.
log_metrics: Whether to log SFT optimizer metrics. Defaults to True.
"""
if config is None:
config = TrainSFTConfig()

backend = self.backend()
backend_logs_sft_metrics = (
log_metrics and self._backend_logs_sft_metrics_remotely(backend)
)

_config = cast(dev.TrainSFTConfig, {**(_config or {})})
if log_metrics:
metric_logging_config: SFTMetricLoggingConfig = {
"enabled": True,
}
if backend_logs_sft_metrics:
metric_logging_config["target_training_step"] = (
await self.get_step()
) + 1
_config["metric_logging"] = metric_logging_config
else:
_config["metric_logging"] = {"enabled": False}

# Train (backend yields metrics for each batch without logging)
# Collect all metrics and aggregate them at the end (same as RL)
_config = _config or {} # ty:ignore[invalid-assignment]
# Collect all metrics and aggregate them at the end for the checkpoint summary.
training_metrics: list[dict[str, float]] = []
local_sft_checkpoint_step: int | None = None
trainer_started = time.monotonic()
async for metrics in self.backend()._train_sft(
async for metrics in backend._train_sft(
self,
trajectories,
config,
_config, # ty:ignore[invalid-argument-type]
verbose,
):
training_metrics.append(metrics)
gradient_step = len(training_metrics)
if log_metrics and not backend_logs_sft_metrics:
if local_sft_checkpoint_step is None:
local_sft_checkpoint_step = await self.get_step() + 1
await self._log_sft_metric_sample(
metrics,
checkpoint_step=local_sft_checkpoint_step,
gradient_step=gradient_step,
)
trainer_elapsed = time.monotonic() - trainer_started

# Log aggregated training metrics once (same as RL)
if training_metrics:
# Log aggregated training metrics once at the checkpoint step. For
# remote-logging backends, the remote SFT job owns this row too.
if training_metrics and log_metrics and not backend_logs_sft_metrics:
avg_metrics = average_metric_samples(training_metrics)
avg_metrics["time/step_trainer_s"] = trainer_elapsed
# Get the current step after training
step = await self.get_step()
await self.log(
trajectories=None, split="train", metrics=avg_metrics, step=step
)

@staticmethod
def _backend_logs_sft_metrics_remotely(backend: "Backend") -> bool:
remote_logger = getattr(type(backend), "logs_sft_metrics_remotely", None)
if not callable(remote_logger):
return False
return bool(remote_logger(backend))

async def _log_sft_metric_sample(
self,
metrics: dict[str, float],
*,
checkpoint_step: int,
gradient_step: int,
) -> None:
await self.log(
trajectories=None,
split=SFT_METRIC_PREFIX,
metrics={
SFT_GRADIENT_STEP_KEY: float(gradient_step),
**metrics,
},
step=checkpoint_step,
)
18 changes: 16 additions & 2 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from contextlib import asynccontextmanager
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal, cast
import warnings

from openai._types import NOT_GIVEN
Expand All @@ -22,7 +22,12 @@
summarize_trajectory_groups,
)
from ..trajectories import Trajectory, TrajectoryGroup
from ..types import ServerlessTrainResult, TrainConfig, TrainSFTConfig
from ..types import (
ServerlessTrainResult,
SFTMetricLoggingConfig,
TrainConfig,
TrainSFTConfig,
)
from ..utils.record_provenance import record_provenance

if TYPE_CHECKING:
Expand Down Expand Up @@ -88,6 +93,9 @@ def __init__(
self._base_url = str(client.base_url)
self._client = client

def logs_sft_metrics_remotely(self) -> bool:
return True

async def close(self) -> None:
await self._client.close() # ty:ignore[possibly-missing-attribute]

Expand Down Expand Up @@ -607,6 +615,12 @@ async def _train_sft(
)
sft_config["batch_size"] = batch_size
sft_config["learning_rate"] = config.learning_rate
metric_logging = cast(
SFTMetricLoggingConfig,
dict(dev_config.get("metric_logging", {}) or {}),
)
if metric_logging.get("enabled"):
sft_config["metric_logging"] = metric_logging

sft_training_job = await self._client.sft_training_jobs.create(
model_id=model.id,
Expand Down
2 changes: 2 additions & 0 deletions src/art/serverless/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing_extensions import override

from ..trajectories import TrajectoryGroup
from ..types import SFTMetricLoggingConfig

ResponseT = TypeVar("ResponseT")

Expand Down Expand Up @@ -80,6 +81,7 @@ class ExperimentalTrainingConfig(TypedDict, total=False):
class SFTTrainingConfig(TypedDict, total=False):
batch_size: int | None
learning_rate: float | list[float] | None
metric_logging: SFTMetricLoggingConfig | None


class TrainingJob(BaseModel):
Expand Down
6 changes: 6 additions & 0 deletions src/art/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
import pydantic
from pydantic import SkipValidation
from typing_extensions import TypedDict

Message = Annotated[ChatCompletionMessageParam, SkipValidation]
MessageOrChoice = Message | Choice
Expand All @@ -25,6 +26,11 @@ class TrainSFTConfig(pydantic.BaseModel):
batch_size: int | Literal["auto"] = "auto"


class SFTMetricLoggingConfig(TypedDict, total=False):
enabled: bool
target_training_step: int


Verbosity = Literal[0, 1, 2]


Expand Down
3 changes: 3 additions & 0 deletions src/art/utils/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ async def train_sft_from_file(
_config: "DevTrainSFTConfig | None" = None,
verbose: bool = False,
shuffle_buffer_size: int = 10000,
log_metrics: bool = True,
) -> None:
"""
Train a model using supervised fine-tuning from a JSONL file.
Expand All @@ -375,6 +376,7 @@ async def train_sft_from_file(
verbose: Whether to print verbose output. Default: False
shuffle_buffer_size: Size of shuffle buffer. Default: 10000.
Larger values give better shuffling but use more memory.
log_metrics: Whether to log SFT optimizer metrics. Default: True.

Example:
await train_sft_from_file(
Expand Down Expand Up @@ -449,4 +451,5 @@ async def train_sft_from_file(
config,
_config=_config,
verbose=verbose,
log_metrics=log_metrics,
)
Loading
Loading