294 changes: 294 additions & 0 deletions examples/configs/sdpo_gsm8k_1B.yaml
@@ -0,0 +1,294 @@
# SDPO on GSM8K with Qwen2.5-1.5B-Instruct
#
# Self-Distilled Policy Optimization (SDPO):
# arXiv:2601.20802, Hübotter et al. 2026
#
# For each prompt group, the model's own high-reward rollouts serve as
# "demonstrations". The loss is a token-level reverse-KL between the
# policy conditioned on the demonstration (the teacher) and the policy
# without the demonstration (the student).
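#
# Loss sketch (a schematic paraphrase of the above; the exact KL direction
# and weighting follow the paper): at each token t of a student rollout,
#   loss_t ~ w_t * KL( pi(. | demo, x, y_<t), pi(. | x, y_<t) )
# where w_t is a per-token importance-sampling ratio clipped at
# loss_fn.is_clip below.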

sdpo:
  num_prompts_per_step: 16 # number of distinct prompts per step
  num_generations_per_prompt: 8 # generations per prompt (need > 1 for demos)
  max_rollout_turns: 1
  max_num_epochs: 1
  max_num_steps: 1000000
  val_period: 10
  val_at_start: false
  val_at_end: false
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
  # SDPO-specific
  success_reward_threshold: 1.0 # reward must equal 1.0 to count as "successful"
  max_reprompt_len: 4096 # max tokens in the reprompted teacher prompt
  remove_thinking_from_demo: false # strip <think>…</think> from demonstrations
  dont_reprompt_on_self_success: true # don't use own success as demo
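  # Rollout budget: 16 prompts × 8 generations = 128 sequences per step,
  # mirrored by policy.train_global_batch_size. Only rollouts whose reward
  # reaches success_reward_threshold are eligible to serve as demonstrations.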

loss_fn:
  success_reward_threshold: 1.0 # kept in sync with sdpo.success_reward_threshold
  is_clip: 2.0 # clip IS ratio; null to disable
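  # is_clip caps the per-token importance-sampling ratio used to weight the
  # KL terms (presumably to keep off-policy updates bounded); set it to null
  # to disable clipping entirely.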

checkpointing:
  enabled: true
  checkpoint_dir: "results/sdpo_gsm8k"
  metric_name: "val:accuracy"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null
  model_save_format: "safetensors"
  save_consolidated: false
  save_optimizer: true
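  # Checkpoints are ranked by val:accuracy and only the top 3 are kept;
  # save_period: 10 lines up with sdpo.val_period: 10, so each save
  # coincides with a fresh validation score.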

policy:
  model_name: "Qwen/Qwen2.5-1.5B-Instruct"
  tokenizer:
    name: ${policy.model_name}
    chat_template_kwargs: null
  hf_config_overrides: {}
  train_global_batch_size: 128 # 16 prompts × 8 gens
  train_micro_batch_size: 4
  generation_batch_size: 32
  logprob_batch_size: ${policy.train_micro_batch_size}
  max_total_sequence_length: 1024
  precision: "bfloat16"
  logprob_chunk_size: null
  offload_optimizer_for_logprob: false
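  # Batch arithmetic: 128 global / 4 micro = 32 gradient-accumulation
  # microbatches per optimizer step, and the 128 rollouts are generated in
  # 4 passes of generation_batch_size: 32.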

  dtensor_cfg:
    _v2: true
    enabled: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null
    automodel_kwargs: {}
    lora_cfg:
      enabled: false
      target_modules: []
      exclude_modules: []
      match_all_linear: true
      dim: 8
      alpha: 32
      dropout: 0.0
      dropout_position: "post"
      lora_A_init: "xavier"
      use_triton: true

  megatron_cfg:
    enabled: false
    force_reconvert_from_hf: false
    empty_unused_memory_level: 1
    activation_checkpointing: false
    converter_type: "Qwen2ForCausalLM"
    tensor_model_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    pipeline_model_parallel_size: 1
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    context_parallel_size: 1
    pipeline_dtype: ${policy.precision}
    sequence_parallel: false
    freeze_moe_router: true
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none"
    moe_router_bias_update_rate: 0.0
    moe_permute_fusion: false
    apply_rope_fusion: true
    bias_activation_fusion: true
    defer_fp32_logits: false
    moe_per_layer_logging: false
    moe_enable_deepep: false
    moe_token_dispatcher_type: "alltoall"
    moe_shared_expert_overlap: false

    peft:
      enabled: false
      target_modules: []
      exclude_modules: []
      dim: 8
      alpha: 32
      dropout: 0.0
      dropout_position: "post"
      lora_A_init_method: "xavier"
      lora_B_init_method: "zero"
      a2a_experimental: false
      lora_dtype: null

    optimizer:
      optimizer: "adam"
      lr: 1.0e-5
      min_lr: 1.0e-6
      weight_decay: 0.01
      bf16: true
      fp16: false
      params_dtype: "float32"
      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1.0e-8
      sgd_momentum: 0.9
      use_distributed_optimizer: true
      use_precision_aware_optimizer: true
      clip_grad: ${policy.max_grad_norm}
      optimizer_cpu_offload: false
      optimizer_offload_fraction: 0.0

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: 1000
      lr_warmup_iters: 10
      lr_warmup_init: 1.0e-6

    distributed_data_parallel_config:
      grad_reduce_in_fp32: false
      overlap_grad_reduce: true
      overlap_param_gather: true
      use_custom_fsdp: false
      data_parallel_sharding_strategy: "optim_grads_params"

    fp8_cfg:
      enabled: false
      fp8: "e4m3"
      fp8_recipe: "blockwise"
      fp8_param: false

    env_vars: null

  draft:
    enabled: false
    model_name: null
    loss_weight: 0.1
    num_layers: null
    aux_layer_indices: null

  dynamic_batching:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64

  sequence_packing:
    enabled: true
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64
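  # Packing budget: each packed microbatch holds up to 1024 × 4 = 4096 tokens
  # (max_total_sequence_length × train_micro_batch_size). First-fit-decreasing
  # places sequences longest-first into the fewest bins under that budget; the
  # "modified" variant is a repo-specific refinement of the heuristic.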

  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 1.0e-5
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1.0e-8
      foreach: false
      fused: false

  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 10
    - name: "torch.optim.lr_scheduler.ConstantLR"
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones: [10]
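  # The two schedulers chain like torch's SequentialLR: LinearLR warms the
  # learning rate from 0.1 × 1e-5 up to 1e-5 over the first 10 steps (the
  # trailing milestones entry marks the handoff), then ConstantLR holds it.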

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    mcore_generation_config:
      buffer_size_gb: 10
      num_cuda_graphs: 4
      block_size_tokens: 256
      use_cuda_graphs_for_non_decode_steps: true
      enable_chunked_prefill: true
      unified_memory_level: 0
      max_tokens: 16384
    vllm_cfg:
      async_engine: false
      precision: ${policy.precision}
      kv_cache_dtype: "auto"
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      expert_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      enforce_eager: false
      use_deep_gemm: false
      num_last_layers_in_bf16: 0
      num_first_layers_in_bf16: 0
      enable_vllm_metrics_logger: true
      vllm_metrics_logger_interval: 0.5
      vllm_kwargs: {}
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null
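  # Colocated generation: vLLM runs on the same GPUs as training, so
  # gpu_memory_utilization: 0.6 leaves roughly 40% of device memory for the
  # policy's weights, gradients, and optimizer state.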

data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  shuffle: true
  num_workers: 1
  use_multiple_dataloader: false
  train:
    dataset_name: gsm8k
  validation:
    dataset_name: gsm8k
    split: "test"
  default:
    prompt_file: "examples/prompts/gsm8k.txt"
    system_prompt_file: null
    processor: "math_hf_data_processor"
    env_name: "math"
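  # Data flow: each GSM8K item is rendered with examples/prompts/gsm8k.txt by
  # math_hf_data_processor and routed to the "math" environment below for
  # reward scoring.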

env:
  math:
    num_workers: 8
    math_verify_impl: "hf_math_verify"

logger:
  log_dir: "logs/sdpo_gsm8k"
  num_val_samples_to_print: 0
  wandb_enabled: false
  tensorboard_enabled: false
  mlflow_enabled: false
  swanlab_enabled: false
  monitor_gpus: true
  wandb:
    project: "sdpo-dev"
    name: "sdpo-gsm8k-1B"
  swanlab:
    project: "sdpo-dev"
    name: "sdpo-gsm8k-1B"
  tensorboard: {}
  mlflow:
    experiment_name: "sdpo-dev"
    run_name: "sdpo-gsm8k-1B"
    tracking_uri: "http://localhost:5000"
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10

cluster:
  gpus_per_node: 1
  num_nodes: 1
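  # Single-GPU footprint: 1 node × 1 GPU. Every parallelism degree above is 1,
  # and generation is colocated, so training and vLLM share the one device.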