diff --git a/examples/configs/sdpo_gsm8k_1B.yaml b/examples/configs/sdpo_gsm8k_1B.yaml
new file mode 100644
index 0000000000..63b4281f48
--- /dev/null
+++ b/examples/configs/sdpo_gsm8k_1B.yaml
@@ -0,0 +1,294 @@
+# SDPO on GSM8K with Qwen2.5-1.5B-Instruct
+#
+# Self-Distilled Policy Optimization (SDPO):
+# arXiv:2601.20802, Hübotter et al. 2026
+#
+# For each prompt group the model's own high-reward rollouts serve as
+# "demonstrations". The loss is a token-level reverse-KL between the
+# policy conditioned on the demonstration (teacher) and the policy without
+# the demonstration (student).
+
+sdpo:
+ num_prompts_per_step: 16 # number of distinct prompts per step
+ num_generations_per_prompt: 8 # generations per prompt (need > 1 for demos)
+ max_rollout_turns: 1
+ max_num_epochs: 1
+ max_num_steps: 1000000
+ val_period: 10
+ val_at_start: false
+ val_at_end: false
+ max_val_samples: 256
+ val_batch_size: 256
+ seed: 42
+ # SDPO-specific
+ success_reward_threshold: 1.0 # reward must equal 1.0 to count as "successful"
+ max_reprompt_len: 4096 # max tokens in the reprompted teacher prompt
+  remove_thinking_from_demo: false # strip <think>...</think> blocks from demonstrations
+ dont_reprompt_on_self_success: true # don't use own success as demo
+
+loss_fn:
+ success_reward_threshold: 1.0 # kept in sync with sdpo.success_reward_threshold
+ is_clip: 2.0 # clip IS ratio; null to disable
+
+checkpointing:
+ enabled: true
+ checkpoint_dir: "results/sdpo_gsm8k"
+  metric_name: "val_reward" # must match a key in the saved SDPO training state
+ higher_is_better: true
+ keep_top_k: 3
+ save_period: 10
+ checkpoint_must_save_by: null
+ model_save_format: "safetensors"
+ save_consolidated: false
+ save_optimizer: true
+
+policy:
+ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+ tokenizer:
+ name: ${policy.model_name}
+ chat_template_kwargs: null
+ hf_config_overrides: {}
+ train_global_batch_size: 128 # 16 prompts × 8 gens
+ train_micro_batch_size: 4
+ generation_batch_size: 32
+ logprob_batch_size: ${policy.train_micro_batch_size}
+ max_total_sequence_length: 1024
+ precision: "bfloat16"
+ logprob_chunk_size: null
+ offload_optimizer_for_logprob: false
+
+ dtensor_cfg:
+ _v2: true
+ enabled: true
+ cpu_offload: False
+ sequence_parallel: false
+ activation_checkpointing: false
+ tensor_parallel_size: 1
+ context_parallel_size: 1
+ custom_parallel_plan: null
+ automodel_kwargs: {}
+ lora_cfg:
+ enabled: False
+ target_modules: []
+ exclude_modules: []
+ match_all_linear: true
+ dim: 8
+ alpha: 32
+ dropout: 0.0
+ dropout_position: "post"
+ lora_A_init: "xavier"
+ use_triton: true
+
+ megatron_cfg:
+ enabled: false
+ force_reconvert_from_hf: False
+ empty_unused_memory_level: 1
+ activation_checkpointing: false
+ converter_type: "Qwen2ForCausalLM"
+ tensor_model_parallel_size: 1
+ expert_tensor_parallel_size: 1
+ expert_model_parallel_size: 1
+ pipeline_model_parallel_size: 1
+ num_layers_in_first_pipeline_stage: null
+ num_layers_in_last_pipeline_stage: null
+ context_parallel_size: 1
+ pipeline_dtype: ${policy.precision}
+ sequence_parallel: false
+ freeze_moe_router: true
+ moe_router_dtype: "fp64"
+ moe_router_load_balancing_type: "none"
+ moe_router_bias_update_rate: 0.0
+ moe_permute_fusion: false
+ apply_rope_fusion: True
+ bias_activation_fusion: True
+ defer_fp32_logits: False
+ moe_per_layer_logging: False
+ moe_enable_deepep: false
+ moe_token_dispatcher_type: "alltoall"
+ moe_shared_expert_overlap: false
+
+ peft:
+ enabled: false
+ target_modules: []
+ exclude_modules: []
+ dim: 8
+ alpha: 32
+ dropout: 0.0
+ dropout_position: "post"
+ lora_A_init_method: "xavier"
+ lora_B_init_method: "zero"
+ a2a_experimental: false
+ lora_dtype: None
+
+ optimizer:
+ optimizer: "adam"
+ lr: 1.0e-5
+ min_lr: 1.0e-6
+ weight_decay: 0.01
+ bf16: true
+ fp16: false
+ params_dtype: "float32"
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_eps: 1e-8
+ sgd_momentum: 0.9
+ use_distributed_optimizer: true
+ use_precision_aware_optimizer: true
+ clip_grad: ${policy.max_grad_norm}
+ optimizer_cpu_offload: false
+ optimizer_offload_fraction: 0.0
+
+ scheduler:
+ start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+ end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+ weight_decay_incr_style: "constant"
+ lr_decay_style: "constant"
+ lr_decay_iters: 1000
+ lr_warmup_iters: 10
+ lr_warmup_init: 1.0e-6
+
+ distributed_data_parallel_config:
+ grad_reduce_in_fp32: false
+ overlap_grad_reduce: true
+ overlap_param_gather: true
+ use_custom_fsdp: false
+ data_parallel_sharding_strategy: "optim_grads_params"
+
+ fp8_cfg:
+ enabled: false
+ fp8: "e4m3"
+ fp8_recipe: "blockwise"
+ fp8_param: false
+
+ env_vars: null
+
+ draft:
+ enabled: false
+ model_name: null
+ loss_weight: 0.1
+ num_layers: null
+ aux_layer_indices: null
+
+ dynamic_batching:
+ enabled: False
+ train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+ logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+ sequence_length_round: 64
+
+ sequence_packing:
+ enabled: True
+ train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+ logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+ algorithm: "modified_first_fit_decreasing"
+ sequence_length_round: 64
+
+ make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
+ max_grad_norm: 1.0
+
+ optimizer:
+ name: "torch.optim.AdamW"
+ kwargs:
+ lr: 1.0e-5
+ weight_decay: 0.01
+ betas: [0.9, 0.999]
+ eps: 1e-8
+ foreach: False
+ fused: False
+
+ scheduler:
+ - name: "torch.optim.lr_scheduler.LinearLR"
+ kwargs:
+ start_factor: 0.1
+ end_factor: 1.0
+ total_iters: 10
+ - name: "torch.optim.lr_scheduler.ConstantLR"
+ kwargs:
+ factor: 1.0
+ total_iters: 10000000000
+ - milestones: [10]
+
+ generation:
+ backend: "vllm"
+ max_new_tokens: ${policy.max_total_sequence_length}
+ temperature: 1.0
+ top_p: 1.0
+ top_k: null
+ stop_token_ids: null
+ stop_strings: null
+ mcore_generation_config:
+ buffer_size_gb: 10
+ num_cuda_graphs: 4
+ block_size_tokens: 256
+ use_cuda_graphs_for_non_decode_steps: true
+ enable_chunked_prefill: true
+ unified_memory_level: 0
+ max_tokens: 16384
+ vllm_cfg:
+ async_engine: false
+ precision: ${policy.precision}
+ kv_cache_dtype: "auto"
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+ expert_parallel_size: 1
+ gpu_memory_utilization: 0.6
+ max_model_len: ${policy.max_total_sequence_length}
+ enforce_eager: False
+ use_deep_gemm: False
+ num_last_layers_in_bf16: 0
+ num_first_layers_in_bf16: 0
+ enable_vllm_metrics_logger: true
+ vllm_metrics_logger_interval: 0.5
+ vllm_kwargs: {}
+ colocated:
+ enabled: true
+ resources:
+ gpus_per_node: null
+ num_nodes: null
+
+data:
+ max_input_seq_length: ${policy.max_total_sequence_length}
+ shuffle: true
+ num_workers: 1
+ use_multiple_dataloader: false
+ train:
+ dataset_name: gsm8k
+ validation:
+ dataset_name: gsm8k
+ split: "test"
+ default:
+ prompt_file: "examples/prompts/gsm8k.txt"
+ system_prompt_file: null
+ processor: "math_hf_data_processor"
+ env_name: "math"
+
+env:
+ math:
+ num_workers: 8
+ math_verify_impl: "hf_math_verify"
+
+logger:
+ log_dir: "logs/sdpo_gsm8k"
+ num_val_samples_to_print: 0
+ wandb_enabled: false
+ tensorboard_enabled: false
+ mlflow_enabled: false
+ swanlab_enabled: false
+ monitor_gpus: true
+ wandb:
+ project: "sdpo-dev"
+ name: "sdpo-gsm8k-1B"
+ swanlab:
+ project: "sdpo-dev"
+ name: "sdpo-gsm8k-1B"
+ tensorboard: {}
+ mlflow:
+ experiment_name: "sdpo-dev"
+ run_name: "sdpo-gsm8k-1B"
+ tracking_uri: "http://localhost:5000"
+ gpu_monitoring:
+ collection_interval: 10
+ flush_interval: 10
+
+cluster:
+ gpus_per_node: 1
+ num_nodes: 1
diff --git a/examples/run_sdpo.py b/examples/run_sdpo.py
new file mode 100644
index 0000000000..a237dbd1d0
--- /dev/null
+++ b/examples/run_sdpo.py
@@ -0,0 +1,106 @@
+import argparse
+import os
+import pprint
+
+from omegaconf import OmegaConf
+
+from nemo_rl.algorithms.sdpo import MasterConfig, sdpo_train, setup
+from nemo_rl.algorithms.utils import get_tokenizer
+from nemo_rl.data.utils import setup_response_data
+from nemo_rl.distributed.virtual_cluster import init_ray
+from nemo_rl.models.generation import configure_generation_config
+from nemo_rl.utils.config import (
+ load_config,
+ parse_hydra_overrides,
+ register_omegaconf_resolvers,
+)
+from nemo_rl.utils.logger import get_next_experiment_dir
+
+
+def parse_args() -> tuple[argparse.Namespace, list[str]]:
+ parser = argparse.ArgumentParser(description="Run SDPO training")
+ parser.add_argument(
+ "--config", type=str, default=None, help="Path to YAML config file"
+ )
+ args, overrides = parser.parse_known_args()
+ return args, overrides
+
+
+def main() -> None:
+ register_omegaconf_resolvers()
+ args, overrides = parse_args()
+
+ if not args.config:
+ args.config = os.path.join(
+ os.path.dirname(__file__), "configs", "sdpo_gsm8k_1B.yaml"
+ )
+
+ config = load_config(args.config)
+ print(f"Loaded configuration from: {args.config}")
+
+ if overrides:
+ print(f"Overrides: {overrides}")
+ config = parse_hydra_overrides(config, overrides)
+
+ config: MasterConfig = OmegaConf.to_container(config, resolve=True)
+ print("Applied CLI overrides")
+
+ print("Final config:")
+ pprint.pprint(config)
+
+ config["logger"]["log_dir"] = get_next_experiment_dir(
+ config["logger"]["log_dir"]
+ )
+ print(f"Using log directory: {config['logger']['log_dir']}")
+
+ init_ray()
+
+ tokenizer = get_tokenizer(config["policy"]["tokenizer"])
+ assert config["policy"]["generation"] is not None, (
+ "A generation config is required for SDPO"
+ )
+ config["policy"]["generation"] = configure_generation_config(
+ config["policy"]["generation"],
+ tokenizer,
+ has_refit_draft_weights=False,
+ )
+
+ (
+ dataset,
+ val_dataset,
+ task_to_env,
+ val_task_to_env,
+ ) = setup_response_data(tokenizer, config["data"], config["env"])
+
+ (
+ policy,
+ policy_generation,
+ cluster,
+ dataloader,
+ val_dataloader,
+ loss_fn,
+ logger,
+ checkpointer,
+ sdpo_state,
+ master_config,
+ ) = setup(config, tokenizer, dataset, val_dataset)
+
+ print("Running SDPO training")
+ sdpo_train(
+ policy,
+ policy_generation,
+ dataloader,
+ val_dataloader,
+ tokenizer,
+ loss_fn,
+ task_to_env,
+ val_task_to_env,
+ logger,
+ checkpointer,
+ sdpo_state,
+ master_config,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py
index ede13323b1..69cef42251 100644
--- a/nemo_rl/algorithms/loss/__init__.py
+++ b/nemo_rl/algorithms/loss/__init__.py
@@ -26,6 +26,9 @@
NLLLossFn,
PreferenceLossDataDict,
PreferenceLossFn,
+ SDPOLossConfig,
+ SDPOLossDataDict,
+ SDPOLossFn,
)
from nemo_rl.algorithms.loss.utils import (
prepare_loss_input,
@@ -51,6 +54,9 @@
"NLLLossFn",
"PreferenceLossDataDict",
"PreferenceLossFn",
+ "SDPOLossConfig",
+ "SDPOLossDataDict",
+ "SDPOLossFn",
"prepare_loss_input",
"prepare_packed_loss_input",
"SequencePackingFusionLossWrapper",
diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py
index df6ff6bc54..fd5e429240 100755
--- a/nemo_rl/algorithms/loss/loss_functions.py
+++ b/nemo_rl/algorithms/loss/loss_functions.py
@@ -1035,3 +1035,124 @@ def __call__(
}
return kl_loss, metrics
+
+
+# ============================================================================
+# SDPO Loss
+# ============================================================================
+
+
+class SDPOLossConfig(TypedDict):
+ """Configuration for the SDPO (Self-Distilled Policy Optimization) loss.
+
+ Defaults:
+ success_reward_threshold: 1.0 - minimum reward to count as "successful"
+ is_clip: 2.0 - clip value for token-level IS ratio (None to disable)
+ """
+
+ success_reward_threshold: float
+ is_clip: Optional[float]
+
+
+class SDPOLossDataDict(TypedDict):
+ """Required keys in the data BatchedDataDict for SDPOLossFn."""
+
+ input_ids: torch.Tensor # [B, seq_len]
+ token_mask: torch.Tensor # [B, seq_len] 1 = response token
+ sample_mask: torch.Tensor # [B] 1 = valid sample
+ prev_logprobs: torch.Tensor # [B, seq_len] old policy logprobs (for IS correction)
+ teacher_logprobs: torch.Tensor # [B, seq_len] teacher logprobs at student positions
+ sdpo_mask: torch.Tensor # [B] 1 = sample has a demonstration
+
+
+class SDPOLossFn(LossFunction):
+ """Self-Distilled Policy Optimization loss.
+
+ Trains the student (current policy on original prompt) to match the teacher
+ (current policy conditioned on a successful demonstration) via a token-level
+ reverse-KL distillation objective:
+
+ L(θ) = E_{t ∈ response, i has demo}
+ [ (log π_θ(t|s) - log π_teacher(t|s, demo)).detach()
+ * log π_θ(t|s)
+ * IS_clip(π_θ / π_old, c) ]
+
+ where:
+ - π_θ(t|s) is the current student policy (being optimised)
+ - π_teacher(t|s,demo) is the current model conditioned on the demonstration
+ (pre-computed and stored in data["teacher_logprobs"])
+ - IS_clip clips the importance-sampling ratio to [0, is_clip] to stabilise
+ off-policy multi-step updates
+
+ References:
+ Hübotter et al. (2026) "Reinforcement Learning via Self-Distillation"
+ arXiv:2601.20802
+ """
+
+ loss_type = LossType.TOKEN_LEVEL
+ input_type = LossInputType.LOGPROB
+
+ def __init__(self, cfg: SDPOLossConfig):
+ self.is_clip = cfg.get("is_clip", 2.0)
+
+ def __call__(
+ self,
+ next_token_logprobs: Tensor,
+ data: BatchedDataDict,
+ global_valid_seqs: torch.Tensor,
+ global_valid_toks: torch.Tensor,
+ ) -> tuple[torch.Tensor, dict[str, Any]]:
+ """Compute the SDPO self-distillation loss.
+
+ Args:
+ next_token_logprobs: current-policy log-probs, shape [B, seq_len].
+ data: must contain keys defined in SDPOLossDataDict.
+ global_valid_seqs: number of valid sequences in this microbatch.
+ global_valid_toks: number of valid tokens in this microbatch.
+
+ Returns:
+ (loss, metrics)
+ """
+ # next_token_logprobs from the training forward has shape [B, S-1]
+ # (convention: logprobs[t] = log P(token[t+1] | context[0:t+1])).
+ # teacher_logprobs / prev_logprobs come from get_logprobs() which uses the
+ # full-sequence convention [B, S] with a dummy 0 at position 0, so we shift
+ # those by 1 to align with next_token_logprobs.
+ student_lp = next_token_logprobs # [B, S-1]
+ teacher_lp = data["teacher_logprobs"][:, 1:] # [B, S-1]
+ token_mask = data["token_mask"][:, 1:] # [B, S-1]
+ sdpo_mask = data["sdpo_mask"] # [B]
+ sample_mask = data["sample_mask"] # [B]
+
+ # Effective mask: response token AND sample has demo AND sample is valid
+ effective_mask = (
+ token_mask
+ * sdpo_mask.unsqueeze(-1).float()
+ * sample_mask.unsqueeze(-1)
+ )
+
+ # Token-level reverse-KL gradient (REINFORCE approximation):
+ # ∇_θ KL(π_θ || π_teacher) ≈ (log_ratio).detach() * ∇_θ log π_θ
+ log_ratio = (student_lp - teacher_lp).detach()
+ per_token_loss = log_ratio * student_lp
+
+ # Optional importance-sampling correction for off-policy updates
+ if self.is_clip is not None:
+ prev_lp = data["prev_logprobs"][:, 1:] # [B, S-1]
+ is_ratio = (student_lp - prev_lp).detach().exp().clamp(max=self.is_clip)
+ per_token_loss = per_token_loss * is_ratio
+
+ loss = masked_mean(
+ per_token_loss,
+ effective_mask,
+ global_normalization_factor=global_valid_toks,
+ )
+
+ frac_with_demo = sdpo_mask.float().mean().item()
+ metrics = {
+ "num_valid_samples": sample_mask.sum().item(),
+ "sdpo/mean_log_ratio": masked_mean(log_ratio, effective_mask).item(),
+ "sdpo/frac_with_demo": frac_with_demo,
+ }
+
+ return loss, metrics
diff --git a/nemo_rl/algorithms/sdpo.py b/nemo_rl/algorithms/sdpo.py
new file mode 100644
index 0000000000..202b08547a
--- /dev/null
+++ b/nemo_rl/algorithms/sdpo.py
@@ -0,0 +1,953 @@
+"""Self-Distilled Policy Optimization (SDPO).
+
+Reference: Hübotter et al. (2026) "Reinforcement Learning via Self-Distillation"
+ arXiv:2601.20802
+
+This module implements SDPO in NeMo-RL. The key idea is:
+ - Roll out the policy (same as GRPO).
+ - For each prompt group find successful responses (reward >= threshold).
+ - For every sample build a "teacher" input: original prompt prepended with a
+ successful demonstration. If no demonstration is available the sample is
+ excluded from the distillation loss.
+ - Compute teacher log-probs (current model, enriched context) and align them
+ to the student sequence positions.
+ - Train with token-level reverse-KL distillation (SDPOLossFn) instead of a
+ clipped policy-gradient loss.
+"""
+
+import os
+import re
+import time
+from typing import Any, NotRequired, Optional, TypedDict, TypeVar
+
+import torch
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from nemo_rl.algorithms.grpo import (
+ GRPOLoggerConfig,
+ _extract_prompt_only_messages,
+ _should_use_async_rollouts,
+ refit_policy_generation,
+ validate,
+)
+from nemo_rl.algorithms.loss import (
+ SDPOLossConfig,
+ SDPOLossFn,
+)
+from nemo_rl.algorithms.loss.interfaces import LossFunction
+from nemo_rl.algorithms.utils import set_seed
+from nemo_rl.data import DataConfig
+from nemo_rl.data.collate_fn import rl_collate_fn
+from nemo_rl.data.dataloader import MultipleDataloaderWrapper
+from nemo_rl.data.datasets import AllTaskProcessedDataset
+from nemo_rl.data.interfaces import DatumSpec
+from nemo_rl.data.llm_message_utils import (
+ batched_message_log_to_flat_message,
+ get_keys_from_message_log,
+)
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
+from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
+from nemo_rl.environments.interfaces import EnvironmentInterface
+from nemo_rl.experience.rollouts import (
+ run_async_multi_turn_rollout,
+ run_multi_turn_rollout,
+)
+from nemo_rl.models.generation.interfaces import GenerationInterface
+from nemo_rl.models.policy import PolicyConfig
+from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface
+from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
+from nemo_rl.utils.logger import (
+ Logger,
+ LoggerConfig,
+ print_message_log_samples,
+)
+from nemo_rl.utils.nsys import maybe_gpu_profile_step
+from nemo_rl.utils.timer import TimeoutChecker, Timer
+
+# ============================================================================
+# Configuration
+# ============================================================================
+
+TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)
+
+_DEFAULT_REPROMPT_TEMPLATE = (
+ "{prompt}\n\n"
+ "Here is a correct solution for reference:\n\n"
+ "{solution}\n\n"
+ "Now solve the original problem."
+)
+
+
+class SDPOConfig(TypedDict):
+ """Top-level SDPO training configuration."""
+
+ num_prompts_per_step: int
+ num_generations_per_prompt: int
+ max_rollout_turns: int
+ max_num_epochs: int
+ max_num_steps: int
+ val_period: int
+ val_batch_size: int
+ val_at_start: bool
+ val_at_end: bool
+ max_val_samples: int
+ seed: int
+ # SDPO-specific
+ success_reward_threshold: float # reward >= this counts as "successful"
+ max_reprompt_len: int # max tokens in teacher prompt
+ reprompt_template: NotRequired[str] # uses {prompt} and {solution} placeholders
+    remove_thinking_from_demo: NotRequired[bool]  # strip <think>...</think> blocks from demos
+ dont_reprompt_on_self_success: NotRequired[bool] # exclude own success as demo
+
+
+class SDPOSaveState(TypedDict):
+ consumed_samples: int
+ current_step: int
+ current_epoch: int
+ total_steps: int
+ total_valid_tokens: int
+ val_reward: NotRequired[float]
+
+
+def _default_sdpo_save_state() -> SDPOSaveState:
+ return {
+ "consumed_samples": 0,
+ "current_step": 0,
+ "current_epoch": 0,
+ "total_steps": 0,
+ "total_valid_tokens": 0,
+ "val_reward": -99999999.0,
+ }
+
+
+class MasterConfig(TypedDict):
+ policy: PolicyConfig
+ loss_fn: SDPOLossConfig
+ env: dict[str, Any]
+ data: DataConfig
+ sdpo: SDPOConfig
+ logger: GRPOLoggerConfig
+ cluster: ClusterConfig
+ checkpointing: CheckpointingConfig
+
+
+# ============================================================================
+# Teacher-input construction
+# ============================================================================
+
+
+def _strip_thinking(text: str) -> str:
+ """Remove … blocks from a response string."""
+ return re.sub(r".*?", "", text, flags=re.DOTALL).strip()
+
+
+def build_sdpo_teacher_data(
+ message_logs: list,
+ rewards: torch.Tensor,
+ num_generations: int,
+ tokenizer: TokenizerType,
+ success_reward_threshold: float = 1.0,
+ reprompt_template: str = _DEFAULT_REPROMPT_TEMPLATE,
+ remove_thinking_from_demo: bool = False,
+ dont_reprompt_on_self_success: bool = False,
+ max_reprompt_len: int = 8192,
+ pad_token_id: int = 0,
+ make_seq_len_divisible_by: int = 1,
+) -> tuple[BatchedDataDict, torch.Tensor]:
+ """Build teacher-input sequences for SDPO self-distillation.
+
+ For each prompt group (N generations per prompt), finds the first
+ successful response (reward >= success_reward_threshold) and uses it
+ as the demonstration. Builds teacher sequences as:
+
+ [reprompted_prompt_tokens | original_response_tokens]
+
+ Returns:
+ teacher_data: BatchedDataDict with keys input_ids, input_lengths,
+ token_mask, sample_mask.
+ sdpo_mask: bool tensor of shape [B], True for samples that have a
+ demonstration and will receive a distillation signal.
+ """
+ B = len(message_logs)
+ P = B // num_generations
+
+ teacher_input_ids_list: list[torch.Tensor] = []
+ teacher_token_mask_list: list[torch.Tensor] = []
+ sdpo_mask = torch.zeros(B, dtype=torch.bool)
+
+ for p in range(P):
+ start = p * num_generations
+ end = start + num_generations
+ group_rewards = rewards[start:end]
+ group_logs = message_logs[start:end]
+
+ # Identify successful samples within this group
+ success_indices = [
+ i
+ for i in range(num_generations)
+ if group_rewards[i].item() >= success_reward_threshold
+ ]
+
+ # Per-sample: default = use original (no self-distillation signal)
+ for i in range(num_generations):
+ msg_log = group_logs[i]
+
+ # Extract original assistant response tokens (last assistant turn)
+ response_tokens: Optional[torch.Tensor] = None
+ for m in reversed(msg_log):
+ if m["role"] == "assistant":
+ response_tokens = m["token_ids"]
+ break
+
+ if response_tokens is None:
+ # No assistant message found — append empty; no signal
+ flat_tokens = torch.cat(
+ [m["token_ids"] for m in msg_log], dim=0
+ )
+ flat_mask = torch.cat(
+ [
+ torch.ones_like(m["token_ids"])
+ if m["role"] == "assistant"
+ else torch.zeros_like(m["token_ids"])
+ for m in msg_log
+ ],
+ dim=0,
+ )
+ teacher_input_ids_list.append(flat_tokens)
+ teacher_token_mask_list.append(flat_mask)
+ continue
+
+ # Find a suitable demonstration
+ demo_content: Optional[str] = None
+ if len(success_indices) > 0:
+ for demo_idx in success_indices:
+ if dont_reprompt_on_self_success and demo_idx == i:
+ continue
+ # Extract demo response content
+ for m in reversed(group_logs[demo_idx]):
+ if m["role"] == "assistant":
+ demo_content = m.get("content", "")
+ break
+ if demo_content is not None:
+ break
+
+ if demo_content is None:
+ # No valid demonstration — fall back to original sequence
+ flat_tokens = torch.cat(
+ [m["token_ids"] for m in msg_log], dim=0
+ )
+ flat_mask = torch.cat(
+ [
+ torch.ones_like(m["token_ids"])
+ if m["role"] == "assistant"
+ else torch.zeros_like(m["token_ids"])
+ for m in msg_log
+ ],
+ dim=0,
+ )
+ teacher_input_ids_list.append(flat_tokens)
+ teacher_token_mask_list.append(flat_mask)
+ continue
+
+ # We have a demonstration — build reprompted teacher input
+ sdpo_mask[start + i] = True
+
+ if remove_thinking_from_demo:
+ demo_content = _strip_thinking(demo_content)
+
+ # Build the reprompted prompt messages (keep all non-assistant turns,
+ # modify last user turn to include the demonstration)
+ teacher_messages = []
+ user_turn_count = 0
+ total_user_turns = sum(1 for m in msg_log if m["role"] == "user")
+ for m in msg_log:
+ if m["role"] == "assistant":
+ continue # skip — response is appended separately
+ if m["role"] == "user":
+ user_turn_count += 1
+ if user_turn_count == total_user_turns:
+ # Last user turn: inject demonstration
+ original_content = m.get("content", "")
+ reprompted_content = reprompt_template.format(
+ prompt=original_content,
+ solution=demo_content,
+ )
+ teacher_messages.append(
+ {"role": "user", "content": reprompted_content}
+ )
+ else:
+ teacher_messages.append(
+ {"role": "user", "content": m.get("content", "")}
+ )
+ else:
+ teacher_messages.append(
+ {"role": m["role"], "content": m.get("content", "")}
+ )
+
+ # Tokenize reprompted prompt
+ try:
+ teacher_prompt = tokenizer.apply_chat_template(
+ teacher_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ return_dict=True,
+ truncation=True,
+ max_length=max_reprompt_len,
+ padding=False,
+ )
+ prompt_ids = teacher_prompt["input_ids"][0] # [prompt_len]
+ except Exception:
+ # Tokenisation failed — fall back to original
+ sdpo_mask[start + i] = False
+ flat_tokens = torch.cat(
+ [m["token_ids"] for m in msg_log], dim=0
+ )
+ flat_mask = torch.cat(
+ [
+ torch.ones_like(m["token_ids"])
+ if m["role"] == "assistant"
+ else torch.zeros_like(m["token_ids"])
+ for m in msg_log
+ ],
+ dim=0,
+ )
+ teacher_input_ids_list.append(flat_tokens)
+ teacher_token_mask_list.append(flat_mask)
+ continue
+
+ # Teacher sequence = reprompted prompt + original response
+ teacher_seq = torch.cat([prompt_ids.cpu(), response_tokens.cpu()])
+ prompt_mask = torch.zeros(len(prompt_ids), dtype=torch.long)
+ resp_mask = torch.ones(len(response_tokens), dtype=torch.long)
+ teacher_mask_seq = torch.cat([prompt_mask, resp_mask])
+
+ teacher_input_ids_list.append(teacher_seq)
+ teacher_token_mask_list.append(teacher_mask_seq)
+
+ # Pad all sequences to the same length
+ max_len = max(t.shape[0] for t in teacher_input_ids_list)
+ if make_seq_len_divisible_by > 1:
+ max_len = (
+ (max_len + make_seq_len_divisible_by - 1)
+ // make_seq_len_divisible_by
+ * make_seq_len_divisible_by
+ )
+
+ teacher_input_ids = torch.stack(
+ [
+ torch.nn.functional.pad(
+ t, (0, max_len - t.shape[0]), value=pad_token_id
+ )
+ for t in teacher_input_ids_list
+ ]
+ )
+ teacher_token_mask = torch.stack(
+ [
+ torch.nn.functional.pad(
+ m.long(), (0, max_len - m.shape[0]), value=0
+ )
+ for m in teacher_token_mask_list
+ ]
+ )
+ teacher_input_lengths = torch.tensor(
+ [t.shape[0] for t in teacher_input_ids_list], dtype=torch.long
+ )
+
+ teacher_data: BatchedDataDict = BatchedDataDict(
+ {
+ "input_ids": teacher_input_ids,
+ "input_lengths": teacher_input_lengths,
+ "token_mask": teacher_token_mask,
+ "sample_mask": torch.ones(B, dtype=torch.float32),
+ }
+ )
+
+ return teacher_data, sdpo_mask
+
+
+def align_teacher_logprobs(
+ teacher_logprobs: torch.Tensor, # [B, max_teacher_len]
+ teacher_token_mask: torch.Tensor, # [B, max_teacher_len] 1 = response token
+ student_seq_len: int, # target width
+ student_token_mask: torch.Tensor, # [B, student_seq_len] 1 = response token
+) -> torch.Tensor:
+ """Re-index teacher logprobs to student response positions.
+
+ The teacher sequence has a longer prompt (it includes the demonstration),
+ so response tokens sit at different positions. This function extracts the
+ teacher's response-token log-probs and places them at the corresponding
+ positions in the student sequence layout, returning a tensor of shape
+ [B, student_seq_len].
+ """
+ B = teacher_logprobs.shape[0]
+ aligned = torch.zeros(
+ B,
+ student_seq_len,
+ device=teacher_logprobs.device,
+ dtype=teacher_logprobs.dtype,
+ )
+
+ for i in range(B):
+ teacher_resp_pos = teacher_token_mask[i].bool().nonzero(as_tuple=True)[0]
+ student_resp_pos = student_token_mask[i].bool().nonzero(as_tuple=True)[0]
+
+ n = min(len(teacher_resp_pos), len(student_resp_pos))
+ if n == 0:
+ continue
+
+ # Both sequences contain the same response tokens — map positionally
+ aligned[i, student_resp_pos[:n]] = teacher_logprobs[
+ i, teacher_resp_pos[:n]
+ ]
+
+ return aligned
+
+
+# ============================================================================
+# Setup
+# ============================================================================
+
+
+def setup(
+ master_config: MasterConfig,
+ tokenizer: TokenizerType,
+ dataset: AllTaskProcessedDataset,
+ val_dataset: Optional[AllTaskProcessedDataset],
+) -> tuple:
+ """Set up SDPO training artefacts by delegating to grpo.setup.
+
+ Builds a minimal GRPO-compatible config from the SDPO config so that all
+ cluster/policy/generation initialisation is handled by the single source of
+ truth in grpo.setup. The GRPO loss function returned by grpo.setup is then
+ replaced with SDPOLossFn before returning.
+
+ Returns:
+ (policy, policy_generation, cluster, dataloader, val_dataloader,
+ loss_fn, logger, checkpointer, sdpo_save_state, master_config)
+ """
+ import copy
+ from nemo_rl.algorithms.grpo import setup as grpo_setup
+ from nemo_rl.algorithms.grpo import _default_grpo_save_state
+
+ sdpo_config = master_config["sdpo"]
+ loss_config = master_config["loss_fn"]
+
+ # Build a GRPO-compatible master config by mapping shared fields.
+ # grpo.setup only reads grpo_config for: num_prompts_per_step,
+ # num_generations_per_prompt, max_num_steps, max_num_epochs, seed,
+ # val_period, val_batch_size, val_at_start, val_at_end, max_val_samples,
+ # max_rollout_turns, and a handful of optional flags.
+ grpo_config_stub = {
+ "num_prompts_per_step": sdpo_config["num_prompts_per_step"],
+ "num_generations_per_prompt": sdpo_config["num_generations_per_prompt"],
+ "max_rollout_turns": sdpo_config["max_rollout_turns"],
+ "max_num_epochs": sdpo_config["max_num_epochs"],
+ "max_num_steps": sdpo_config["max_num_steps"],
+ "val_period": sdpo_config["val_period"],
+ "val_batch_size": sdpo_config["val_batch_size"],
+ "val_at_start": sdpo_config["val_at_start"],
+ "val_at_end": sdpo_config["val_at_end"],
+ "max_val_samples": sdpo_config["max_val_samples"],
+ "seed": sdpo_config["seed"],
+ # GRPO-specific flags that SDPO doesn't use — set safe defaults
+ "normalize_rewards": False,
+ "use_leave_one_out_baseline": False,
+ "use_dynamic_sampling": False,
+ "overlong_filtering": False,
+ "skip_reference_policy_logprobs_calculation": True,
+ "seq_logprob_error_threshold": None,
+ "reward_shaping": {"enabled": False},
+ "reward_scaling": {"enabled": False},
+ "adv_estimator": {"name": "grpo", "normalize_rewards": False,
+ "use_leave_one_out_baseline": False, "minus_baseline": False},
+ "async_grpo": {"enabled": False, "max_trajectory_age_steps": 1,
+ "in_flight_weight_updates": False,
+ "recompute_kv_cache_after_weight_updates": False},
+ "batch_multiplier": 1,
+ }
+
+ # GRPO loss_fn stub (grpo.setup uses it only to build ClippedPGLossFn,
+ # which we discard; the values here don't affect anything else in setup).
+ grpo_loss_fn_stub = {
+ "reference_policy_kl_penalty": 0.0,
+ "reference_policy_kl_type": "k3",
+ "kl_input_clamp_value": 20.0,
+ "kl_output_clamp_value": 10.0,
+ "ratio_clip_min": 0.2,
+ "ratio_clip_max": 0.2,
+ "ratio_clip_c": None,
+ "use_on_policy_kl_approximation": False,
+ "use_importance_sampling_correction": False,
+ "truncated_importance_sampling_ratio": None,
+ "truncated_importance_sampling_ratio_min": None,
+ "truncated_importance_sampling_type": "tis",
+ "sequence_level_importance_ratios": False,
+ "token_level_loss": True,
+ "force_on_policy_ratio": False,
+ "use_kl_in_reward": False,
+ }
+
+ grpo_master_config = copy.copy(master_config)
+ grpo_master_config["grpo"] = grpo_config_stub
+ grpo_master_config["loss_fn"] = grpo_loss_fn_stub
+
+ (
+ policy,
+ policy_generation,
+ cluster,
+ dataloader,
+ val_dataloader,
+ _grpo_loss_fn, # discard
+ logger,
+ checkpointer,
+ grpo_save_state,
+ _,
+ ) = grpo_setup(grpo_master_config, tokenizer, dataset, val_dataset)
+
+ # Replace GRPO loss with SDPO loss
+ loss_fn = SDPOLossFn(loss_config)
+
+ # Convert grpo save state to sdpo save state (same fields)
+ sdpo_save_state: SDPOSaveState = {
+ "consumed_samples": grpo_save_state["consumed_samples"],
+ "current_step": grpo_save_state["current_step"],
+ "current_epoch": grpo_save_state["current_epoch"],
+ "total_steps": grpo_save_state["total_steps"],
+ "total_valid_tokens": grpo_save_state["total_valid_tokens"],
+ "val_reward": grpo_save_state.get("val_reward", -99999999.0),
+ }
+
+ return (
+ policy,
+ policy_generation,
+ cluster,
+ dataloader,
+ val_dataloader,
+ loss_fn,
+ logger,
+ checkpointer,
+ sdpo_save_state,
+ master_config,
+ )
+
+
+# ============================================================================
+# Training loop
+# ============================================================================
+
+
def sdpo_train(
    policy: ColocatablePolicyInterface,
    policy_generation: Optional[GenerationInterface],
    dataloader: StatefulDataLoader,
    val_dataloader: Optional[StatefulDataLoader],
    tokenizer: TokenizerType,
    loss_fn: LossFunction,
    task_to_env: dict[str, EnvironmentInterface],
    val_task_to_env: Optional[dict[str, EnvironmentInterface]],
    logger: Logger,
    checkpointer: CheckpointManager,
    sdpo_save_state: SDPOSaveState,
    master_config: MasterConfig,
) -> None:
    """Run SDPO training algorithm.

    Each step: (1) repeat every prompt ``num_generations_per_prompt`` times and
    run rollouts, (2) compute student ("prev") logprobs on the rollouts,
    (3) build demonstration-conditioned teacher inputs from high-reward
    rollouts in the same prompt group, (4) compute teacher logprobs and align
    them to the student token positions, and (5) train with the SDPO loss.
    Validation and checkpointing run periodically per ``master_config``.

    Args:
        policy: Trainable policy; also used for logprob inference.
        policy_generation: Optional dedicated generation backend. When ``None``
            the training policy generates directly and no weight refit occurs.
        dataloader: Stateful training dataloader (its state is checkpointed).
        val_dataloader: Validation dataloader (may be ``None`` if unused).
        tokenizer: Tokenizer used for rollouts and message-log flattening.
        loss_fn: SDPO loss function passed to ``policy.train``.
        task_to_env: Environments that run/score training rollouts.
        val_task_to_env: Environments for validation rollouts (may be ``None``).
        logger: Metrics logger.
        checkpointer: Checkpoint manager (top-k tracking, tmp-dir protocol).
        sdpo_save_state: Mutable progress state; updated in place each step.
        master_config: Full config; this function reads the ``"sdpo"``,
            ``"policy"`` and ``"checkpointing"`` sections.
    """
    timer = Timer()
    timeout = TimeoutChecker(
        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
        fit_last_save_time=True,
    )
    timeout.start_iterations()

    # With no dedicated generation backend, the policy itself generates and
    # weight refits are unnecessary.
    NEED_REFIT = True
    if policy_generation is None:
        policy_generation = policy  # type: ignore
        NEED_REFIT = False
    POLICY_GENERATION_STALE = True
    assert policy_generation is not None

    sdpo_cfg = master_config["sdpo"]
    policy_cfg = master_config["policy"]

    current_step = sdpo_save_state["current_step"]
    total_steps = sdpo_save_state["total_steps"]
    current_epoch = sdpo_save_state["current_epoch"]
    max_num_steps = sdpo_cfg["max_num_steps"]
    max_num_epochs = sdpo_cfg["max_num_epochs"]
    consumed_samples = sdpo_save_state["consumed_samples"]
    total_valid_tokens = sdpo_save_state.get("total_valid_tokens", 0)
    val_at_start = sdpo_cfg["val_at_start"]
    val_at_end = sdpo_cfg["val_at_end"]
    val_period = sdpo_cfg["val_period"]
    colocated_inference = policy_cfg["generation"]["colocated"]["enabled"]

    num_generations = sdpo_cfg["num_generations_per_prompt"]
    success_threshold = sdpo_cfg["success_reward_threshold"]
    max_reprompt_len = sdpo_cfg["max_reprompt_len"]
    reprompt_template = sdpo_cfg.get("reprompt_template", _DEFAULT_REPROMPT_TEMPLATE)
    remove_thinking = sdpo_cfg.get("remove_thinking_from_demo", False)
    dont_self = sdpo_cfg.get("dont_reprompt_on_self_success", False)
    make_div_by = policy_cfg.get("make_sequence_length_divisible_by", 1)

    def _make_validate_config() -> dict[str, Any]:
        # grpo.validate() reads master_config["grpo"] for these keys; provide
        # a shim so SDPO's config values are picked up.
        return {
            **master_config,
            "grpo": {
                "max_val_samples": sdpo_cfg["max_val_samples"],
                "val_batch_size": sdpo_cfg["val_batch_size"],
                "max_rollout_turns": sdpo_cfg["max_rollout_turns"],
            },
        }

    # Run initial validation if requested
    if val_at_start and current_step == 0:
        print("\nRunning initial validation...", flush=True)
        if NEED_REFIT and POLICY_GENERATION_STALE:
            refit_policy_generation(policy, policy_generation, colocated_inference)
            POLICY_GENERATION_STALE = False
        else:
            policy_generation.prepare_for_generation()
        val_metrics, _ = validate(
            policy_generation,
            val_dataloader,
            tokenizer,
            val_task_to_env,
            total_steps,
            _make_validate_config(),
        )
        logger.log_metrics(val_metrics, step=total_steps)

    while current_epoch < max_num_epochs and total_steps < max_num_steps:
        print(
            f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}"
        )

        for batch in dataloader:
            metrics: dict[str, Any] = {}

            print(
                f"\n{'=' * 25} Step {current_step + 1} {'=' * 25}",
                flush=True,
            )
            maybe_gpu_profile_step(policy, total_steps + 1)
            val_metrics = None

            with timer.time("total_step_time"):
                # ── Prepare batch ────────────────────────────────────────────
                print("Preparing batch...", flush=True)
                with timer.time("data_processing"):
                    # Each prompt appears num_generations times so that the
                    # prompt group provides its own demonstrations.
                    repeated_batch: BatchedDataDict[DatumSpec] = (
                        batch.repeat_interleave(num_generations)
                    )
                    # NOTE(review): the flattened result below is never read —
                    # it is recomputed after the rollout. Kept in case the call
                    # has side effects on message_log; confirm and remove.
                    _prompt_flat, _prompt_lengths = (
                        batched_message_log_to_flat_message(
                            repeated_batch["message_log"],
                            pad_value_dict={"token_ids": tokenizer.pad_token_id},
                        )
                    )

                # ── Generate responses ───────────────────────────────────────
                print(
                    f"Generating responses (batch size {repeated_batch.size})...",
                    flush=True,
                )
                with timer.time("prepare_for_generation/total"):
                    if NEED_REFIT and POLICY_GENERATION_STALE:
                        refit_policy_generation(
                            policy, policy_generation, colocated_inference
                        )
                        POLICY_GENERATION_STALE = False
                    else:
                        if colocated_inference:
                            policy.offload_after_refit()
                        policy_generation.prepare_for_generation()

                with timer.time("generation"):
                    if _should_use_async_rollouts(master_config):
                        repeated_batch, rollout_metrics = (
                            run_async_multi_turn_rollout(
                                policy_generation=policy_generation,
                                input_batch=repeated_batch,
                                tokenizer=tokenizer,
                                task_to_env=task_to_env,
                                max_seq_len=policy_cfg[
                                    "max_total_sequence_length"
                                ],
                                max_rollout_turns=sdpo_cfg["max_rollout_turns"],
                                greedy=False,
                            )
                        )
                    else:
                        repeated_batch, rollout_metrics = run_multi_turn_rollout(
                            policy_generation=policy_generation,
                            input_batch=repeated_batch,
                            tokenizer=tokenizer,
                            task_to_env=task_to_env,
                            max_seq_len=policy_cfg["max_total_sequence_length"],
                            max_rollout_turns=sdpo_cfg["max_rollout_turns"],
                            greedy=False,
                        )

                metrics.update(rollout_metrics)

                # ── Evaluate rewards ─────────────────────────────────────────
                rewards = repeated_batch["total_reward"]  # [B]

                # ── Build training data (student) ────────────────────────────
                with timer.time("data_processing"):
                    # Loss applies only to assistant tokens; every message also
                    # needs generation_logprobs so flattening yields a full
                    # [B, seq] tensor.
                    for i, message_log in enumerate(repeated_batch["message_log"]):
                        for j, message in enumerate(message_log):
                            if message["role"] == "assistant":
                                message["token_loss_mask"] = torch.ones_like(
                                    message["token_ids"]
                                )
                            else:
                                message["token_loss_mask"] = torch.zeros_like(
                                    message["token_ids"]
                                )
                            if "generation_logprobs" not in message:
                                message["generation_logprobs"] = torch.zeros_like(
                                    message["token_ids"], dtype=torch.float32
                                )

                    flat_messages, input_lengths = batched_message_log_to_flat_message(
                        repeated_batch["message_log"],
                        pad_value_dict={"token_ids": tokenizer.pad_token_id},
                        make_sequence_length_divisible_by=make_div_by,
                    )

                    train_data = BatchedDataDict(
                        {
                            "input_ids": flat_messages["token_ids"],
                            "input_lengths": input_lengths,
                            "generation_logprobs": flat_messages[
                                "generation_logprobs"
                            ],
                            "token_mask": flat_messages["token_loss_mask"],
                            "sample_mask": torch.ones(
                                repeated_batch.size, dtype=torch.float32
                            ),
                        }
                    )
                    train_data.to("cpu")

                # ── Compute student (prev) logprobs ──────────────────────────
                print("Preparing for logprob inference...", flush=True)
                with timer.time("logprob_inference_prep"):
                    policy.prepare_for_lp_inference()

                print("Computing student logprobs...", flush=True)
                with timer.time("student_logprobs"):
                    logprob_data = BatchedDataDict(
                        {
                            "input_ids": train_data["input_ids"],
                            "input_lengths": train_data["input_lengths"],
                            "token_mask": flat_messages["token_loss_mask"],
                            "sample_mask": train_data["sample_mask"],
                        }
                    )
                    train_data["prev_logprobs"] = policy.get_logprobs(
                        logprob_data, timer=timer
                    )["logprobs"]
                    del logprob_data

                # ── Build teacher inputs ─────────────────────────────────────
                print("Building SDPO teacher inputs...", flush=True)
                with timer.time("teacher_input_construction"):
                    # Reprompt each sample with a successful rollout from its
                    # own prompt group; sdpo_mask marks samples that got a demo.
                    teacher_data, sdpo_mask = build_sdpo_teacher_data(
                        message_logs=repeated_batch["message_log"],
                        rewards=rewards,
                        num_generations=num_generations,
                        tokenizer=tokenizer,
                        success_reward_threshold=success_threshold,
                        reprompt_template=reprompt_template,
                        remove_thinking_from_demo=remove_thinking,
                        dont_reprompt_on_self_success=dont_self,
                        max_reprompt_len=max_reprompt_len,
                        pad_token_id=tokenizer.pad_token_id,
                        make_seq_len_divisible_by=make_div_by,
                    )
                    teacher_data.to("cpu")

                    frac_with_demo = sdpo_mask.float().mean().item()
                    metrics["sdpo/frac_with_demo_pre_train"] = frac_with_demo
                    print(
                        f" SDPO: {sdpo_mask.sum().item()}/{len(sdpo_mask)} samples "
                        f"have demonstrations ({100 * frac_with_demo:.1f}%)",
                        flush=True,
                    )

                # ── Compute teacher logprobs ──────────────────────────────────
                print("Computing teacher logprobs...", flush=True)
                with timer.time("teacher_logprobs"):
                    teacher_lp_raw = policy.get_logprobs(
                        teacher_data, timer=timer
                    )["logprobs"]  # [B, max_teacher_len]

                # Align teacher logprobs to student sequence positions
                with timer.time("teacher_logprob_alignment"):
                    student_seq_len = train_data["input_ids"].shape[1]
                    teacher_logprobs_aligned = align_teacher_logprobs(
                        teacher_logprobs=teacher_lp_raw.cpu(),
                        teacher_token_mask=teacher_data["token_mask"].cpu(),
                        student_seq_len=student_seq_len,
                        student_token_mask=train_data["token_mask"].cpu(),
                    )

                train_data["teacher_logprobs"] = teacher_logprobs_aligned
                train_data["sdpo_mask"] = sdpo_mask.float()
                del teacher_data, teacher_lp_raw, teacher_logprobs_aligned

                # ── Train ────────────────────────────────────────────────────
                print("Preparing for training...", flush=True)
                with timer.time("training_prep"):
                    policy.prepare_for_training()
                    # Weights change below, so the generation backend must be
                    # refit before the next generation/validation.
                    POLICY_GENERATION_STALE = True

                print("Training policy (SDPO)...", flush=True)
                with timer.time("policy_training"):
                    train_results = policy.train(
                        train_data,
                        loss_fn,
                        timer=timer,
                    )

                # ── Metrics & logging ────────────────────────────────────────
                metrics["train/loss"] = train_results.get("loss", float("nan"))
                metrics["train/grad_norm"] = train_results.get(
                    "grad_norm", float("nan")
                )
                metrics["train/mean_reward"] = rewards.mean().item()
                metrics["train/success_fraction"] = (
                    (rewards >= success_threshold).float().mean().item()
                )

                # Aggregate SDPO-specific metrics from the loss function
                for k, v in train_results.get("all_mb_metrics", {}).items():
                    if "sdpo" in k:
                        metrics[k] = (
                            sum(v) / len(v) if isinstance(v, list) else v
                        )

                # Valid tokens = loss-masked tokens of non-dropped samples.
                num_valid_tokens = int(
                    (
                        train_data["token_mask"]
                        * train_data["sample_mask"].unsqueeze(-1)
                    )
                    .sum()
                    .item()
                )
                total_valid_tokens += num_valid_tokens
                metrics["train/num_valid_tokens"] = num_valid_tokens
                metrics["train/total_valid_tokens"] = total_valid_tokens
                consumed_samples += repeated_batch.size
                metrics["train/consumed_samples"] = consumed_samples

                # ── Validation ───────────────────────────────────────────────
                is_last_step = (total_steps + 1 >= max_num_steps) or (
                    (current_epoch + 1 == max_num_epochs)
                    and (current_step + 1 == len(dataloader))
                )

                if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
                    val_at_end and is_last_step
                ):
                    if NEED_REFIT and POLICY_GENERATION_STALE:
                        refit_policy_generation(
                            policy, policy_generation, colocated_inference
                        )
                        POLICY_GENERATION_STALE = False
                    val_metrics, _ = validate(
                        policy_generation,
                        val_dataloader,
                        tokenizer,
                        val_task_to_env,
                        total_steps,
                        _make_validate_config(),
                    )
                    metrics.update(val_metrics)

                    # Update best val reward in save state
                    val_reward_key = "val:accuracy"
                    if val_reward_key in val_metrics:
                        sdpo_save_state["val_reward"] = max(
                            sdpo_save_state.get("val_reward", -1e9),
                            val_metrics[val_reward_key],
                        )

                # Log
                logger.log_metrics(metrics, step=total_steps + 1)

                # ── Checkpoint ───────────────────────────────────────────────
                current_step += 1
                total_steps += 1
                sdpo_save_state.update(
                    {
                        "current_step": current_step,
                        "total_steps": total_steps,
                        "current_epoch": current_epoch,
                        "consumed_samples": consumed_samples,
                        "total_valid_tokens": total_valid_tokens,
                    }
                )

                timeout.mark_iteration()
                should_save_by_step = (
                    is_last_step
                    or total_steps % master_config["checkpointing"]["save_period"] == 0
                )
                should_save_by_timeout = timeout.check_save()

                if master_config["checkpointing"]["enabled"] and (
                    should_save_by_step or should_save_by_timeout
                ):
                    # Track metric for top-k checkpointing
                    full_metric_name = master_config["checkpointing"]["metric_name"]
                    if ":" in full_metric_name:
                        prefix, metric_name = full_metric_name.split(":", 1)
                        metrics_source = metrics if prefix == "train" else val_metrics
                        if metrics_source and metric_name in metrics_source:
                            sdpo_save_state[full_metric_name] = metrics_source[
                                metric_name
                            ]

                    print(
                        f"Saving checkpoint for step {total_steps}...", flush=True
                    )
                    checkpoint_path = checkpointer.init_tmp_checkpoint(
                        total_steps, sdpo_save_state, master_config
                    )
                    policy.save_checkpoint(
                        weights_path=os.path.join(
                            checkpoint_path, "policy", "weights"
                        ),
                        optimizer_path=os.path.join(
                            checkpoint_path, "policy", "optimizer"
                        )
                        if checkpointer.save_optimizer
                        else None,
                        tokenizer_path=os.path.join(
                            checkpoint_path, "policy", "tokenizer"
                        ),
                        checkpointing_cfg=master_config["checkpointing"],
                    )
                    # Persist dataloader position so resumption continues
                    # mid-epoch from the same sample.
                    torch.save(
                        dataloader.state_dict(),
                        os.path.join(checkpoint_path, "train_dataloader.pt"),
                    )
                    checkpointer.finalize_checkpoint(checkpoint_path)

                if should_save_by_timeout:
                    print("Timeout reached, stopping training early.", flush=True)
                    return

        current_epoch += 1
        current_step = 0
        sdpo_save_state["current_epoch"] = current_epoch
        sdpo_save_state["current_step"] = 0

    print("SDPO training complete.", flush=True)