Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
86e90ad
Add off-policy distillation with MATH/MMLU eval and IPC optimization
Feb 27, 2026
f1e9bb9
Commit before refactoring
Mar 4, 2026
46fba91
Simplify off-policy distillation IPC path and config
Mar 4, 2026
0831068
Working IPC TP=1
Mar 9, 2026
ee58833
Per-microbatch IPC teacher logits with TP=4 support
Mar 11, 2026
b7481d9
Clean up unused scripts and old distillation module
Mar 11, 2026
dba34ca
Add IPC/non-IPC toggle for off-policy distillation
Mar 12, 2026
0c94f0f
Integrate cross-tokenizer distillation (TokenAligner) into NeMo RL
Mar 16, 2026
ad2223a
Add cross-tokenizer off-policy distillation README for code review
Mar 18, 2026
0662b63
Add gold loss, Phi-4 teacher support, MMLU few-shot eval, and Py3.10 …
Mar 18, 2026
92d00a3
made teacher inference efficient, not passing loss_fn
Mar 31, 2026
e37128f
Cache loss_fn on workers to avoid Ray re-serialization and optimize s…
Apr 1, 2026
fb0dc02
Build cross-tokenizer loss fn locally on workers and shard data per D…
Apr 1, 2026
f77a4c7
Add char-offset alignment method and DP fallback
Apr 7, 2026
11d0361
Add CUDA-backed token alignment DP implementation.
Apr 8, 2026
5f80087
Refactor off-policy cross-tokenizer processing to pipeline next-batch…
Apr 8, 2026
f9a6346
Tune cross-tokenizer worker configuration for large-scale runs.
Apr 8, 2026
2ca1b94
Rebased code based on main branch.
Apr 12, 2026
ec759fc
Rebased code based on main branch.
Apr 12, 2026
c48612d
Align off-policy distillation with NeMo-RL conventions and relocate x…
Apr 13, 2026
d8c2e28
Refactor DTensor off-policy distillation IPC flow and stabilize xtoke…
Apr 13, 2026
1d5af64
fix(off-policy-distillation): normalize non-IPC teacher top-k values …
Apr 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
429 changes: 429 additions & 0 deletions CROSS_TOKENIZER_README.md

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions eval_results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Model,top_k,MATH (%),MATH (correct),MATH (total),MMLU (%),MMLU (correct),MMLU (total),wandb
Llama-3.1-8B (teacher),--,12.56,628,5000,35.14,4935,14042,
Llama-3.2-1B (pretrained baseline),--,5.64,282,5000,23.06,3238,14042,
Distillation forward KL (50 steps),64,5.52,276,5000,23.25,3265,14042,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/uc8nnmh1
Distillation forward KL (1000 steps),64,5.84,292,5000,26.24,3684,14042,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/uc8nnmh1
Delta from baseline (top_k=64),,+0.20,+10,,+3.18,+446,,
Distillation forward KL (50 steps),4096,5.27,27,512,26.56,136,512,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/8a7dictz
Distillation forward KL (1000 steps),4096,5.27,27,512,26.56,136,512,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/8a7dictz
SFT (50 steps),--,4.06,203,5000,16.79,2358,14042,https://wandb.ai/nvidia/nemo-sft-arrow-eval/runs/w5fiqlbw
SFT (1000 steps),--,7.30,365,5000,22.13,3108,14042,https://wandb.ai/nvidia/nemo-sft-arrow-eval/runs/w5fiqlbw
Delta from baseline (SFT),,+1.66,+83,,-0.93,-130,,
267 changes: 267 additions & 0 deletions examples/configs/cross_tokenizer_off_policy_arrow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# Cross-Tokenizer Off-Policy Distillation Configuration
# Student: Llama-3.2-1B (128K vocab) <- Teacher: Phi-4-mini-instruct
#
# Requires a precomputed projection matrix. Generate with:
#   python nemo_rl/utils/x_token/minimal_projection_via_multitoken.py \
#     --student-model meta-llama/Llama-3.2-1B \
#     --teacher-model microsoft/Phi-4-mini-instruct
#
# Then optionally enforce exact matches:
#   python nemo_rl/utils/x_token/reapply_exact_map.py \
#     --student-model meta-llama/Llama-3.2-1B \
#     --teacher-model microsoft/Phi-4-mini-instruct \
#     --initial-projection-path cross_tokenizer_data/transformation_counts_via_multitoken.pt

# Maps teacher-vocab logits onto the student vocabulary via the projection matrix.
token_aligner:
  enabled: true
  projection_matrix_path: "cross_tokenizer_data/projection_map_Llama-3.2_to_Phi-4-mini-instruct_multitoken_top_32_double_special.pt"
  use_sparse_format: false
  loss_type: "KL"
  exact_token_match_only: false
  temperature: 1.0
  vocab_topk: 8192
  reverse_kl: false
  projection_matrix_multiplier: 1.0
  max_comb_len: 4
  learnable: false
  project_teacher_to_student: false  # TODO: remove — duplicated under loss_fn
  use_char_offset: false
  use_align_fast: true
  use_cuda_dp: false
  dp_chunk_size: 128

distillation:
  num_prompts_per_step: 768
  num_generations_per_prompt: 1
  max_num_steps: 80000
  max_num_epochs: 1
  val_period: 1000
  val_at_start: false
  max_val_samples: 128
  val_batch_size: 64
  topk_logits_k: 8192
  use_ipc: true
  loss_on_all_tokens: true
  seed: 42
  # Number of CPU processes for cross-tokenizer decode/encode/align.
  # Heuristic currently used: total GPUs / 2 (16 nodes * 8 GPUs = 128 -> 64 workers).
  # Note: this is workload-dependent; tune per batch size and cluster shape.
  cross_tokenizer_num_workers: 64

loss_fn:
  loss_type: "KL"
  temperature: 1.0
  vocab_topk: 8192
  exact_token_match_only: false
  reverse_kl: false
  project_teacher_to_student: false
  gold_loss: true
  xtoken_loss: true
  ce_loss_scale: 0.1
  dynamic_loss_scaling: true

checkpointing:
  enabled: true
  checkpoint_dir: "checkpoints/cross-tokenizer-distillation-llama1b-phi4mini-instruct"
  metric_name: "train:loss"
  higher_is_better: false
  keep_top_k: 3
  save_period: 10
  save_optimizer: true
  model_save_format: "safetensors"
  save_consolidated: false

# Student model (trained).
policy:
  model_name: "meta-llama/Llama-3.2-1B"
  tokenizer:
    name: "meta-llama/Llama-3.2-1B"
    chat_template: null
  train_global_batch_size: 768
  train_micro_batch_size: 1
  max_total_sequence_length: 4096
  precision: "bfloat16"

  dtensor_cfg:
    enabled: true
    _v2: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: true
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 5.0e-5
      weight_decay: 0.1
      betas: [0.9, 0.98]
      # NOTE: written with an explicit mantissa dot — bare "1e-5" is resolved
      # as a *string* (not a float) by PyYAML's YAML 1.1 implicit typing.
      eps: 1.0e-5
      foreach: false
      fused: false

  # Linear warmup for 4000 iters, then cosine decay (SequentialLR milestones).
  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.02
        end_factor: 1.0
        total_iters: 4000
    - name: "torch.optim.lr_scheduler.CosineAnnealingLR"
      kwargs:
        T_max: 76000
        eta_min: 0.0
    - milestones: [4000]

  generation:
    backend: "vllm"
    max_new_tokens: 2048
    temperature: 0.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false
      precision: "bfloat16"
      kv_cache_dtype: "auto"
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      expert_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: 2048
      enforce_eager: false
      use_deep_gemm: false
      num_last_layers_in_bf16: 0
      num_first_layers_in_bf16: 0
      distributed_executor_backend: null
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null

  sequence_packing:
    enabled: false
    train_mb_tokens: 4096
    logprob_mb_tokens: 4096
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  dynamic_batching:
    enabled: false
    train_mb_tokens: 4096
    logprob_mb_tokens: 4096
    sequence_length_round: 64

# Teacher model (frozen; provides top-k logits for distillation).
teacher:
  model_name: "microsoft/Phi-4-mini-instruct"
  tokenizer:
    name: "microsoft/Phi-4-mini-instruct"
    chat_template: null
  precision: "bfloat16"
  train_global_batch_size: 768
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 4096
  max_grad_norm: 1.0
  logprob_chunk_size: null
  offload_optimizer_for_logprob: false

  dtensor_cfg:
    enabled: true
    _v2: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: true
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  dynamic_batching:
    enabled: false
    train_mb_tokens: 4096
    logprob_mb_tokens: 4096
    sequence_length_round: 64

  sequence_packing:
    enabled: false
    train_mb_tokens: 4096
    logprob_mb_tokens: 4096
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 5.0e-5
      weight_decay: 0.1
      betas: [0.9, 0.98]
      # Explicit mantissa dot — see note on policy.optimizer.kwargs.eps.
      eps: 1.0e-5
      foreach: false
      fused: false

  generation: null

data:
  max_input_seq_length: 4096
  shuffle: true
  train:
    dataset_name: "arrow_text"
    # Preferred local path input for large-scale jobs (set via CLI override or submit script).
    arrow_files: null
    prompt_file: null
    default:
      # Fallback dataset used when train.arrow_files is not provided.
      dataset_path: "allenai/c4"
      hf_dataset_name: "allenai/c4"
      hf_dataset_subset: "en"
      hf_split: "train"
      text_key: "text"

eval:
  val_period: 1000
  val_at_start: false
  max_val_samples: 512
  val_batch_size: 64
  max_rollout_turns: 1
  benchmarks:
    math:
      dataset_name: "math"
      prompt_file: "examples/prompts/cot.txt"
      env:
        num_workers: 8
    mmlu:
      dataset_name: "mmlu"
      prompt_file: "examples/prompts/mmlu.txt"
      env:
        num_workers: 8
        verifier_type: "multilingual_multichoice"
    mmlu_5shot:
      dataset_name: "mmlu"
      prompt_file: "examples/prompts/mmlu.txt"
      num_few_shot: 5
      env:
        num_workers: 8
        verifier_type: "multilingual_multichoice"

logger:
  log_dir: "logs/cross-tokenizer-distillation-llama1b-phi4mini-instruct"
  num_val_samples_to_print: 5
  wandb_enabled: true
  swanlab_enabled: false
  mlflow_enabled: false
  tensorboard_enabled: false
  monitor_gpus: true
  wandb:
    project: "nemo-cross-tokenizer-distillation"
    name: "cross-tokenizer-llama1b-phi4mini-instruct-bs768"
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10

cluster:
  gpus_per_node: 8
  num_nodes: 16
2 changes: 1 addition & 1 deletion examples/configs/distillation_math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ teacher:
dtensor_cfg:
<<: *DTENSOR_BASE
context_parallel_size: 2
tensor_parallel_size: 4
tensor_parallel_size: 2

data:
max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
Expand Down
17 changes: 17 additions & 0 deletions examples/configs/evals/llama_math_eval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Math evaluation for Llama 3.2 1B (base model).
# Override generation.model_name with checkpoint path at runtime:
#   generation.model_name=/path/to/checkpoint/policy/weights
defaults: "eval.yaml"

generation:
  model_name: "meta-llama/Llama-3.2-1B"
  vllm_cfg:
    max_model_len: 2048

tokenizer:
  # Follow whatever model is being evaluated (base or checkpoint override).
  name: ${generation.model_name}
  chat_template: null  # base model, no chat formatting

data:
  prompt_file: "examples/prompts/cot.txt"
  dataset_name: "math"
21 changes: 21 additions & 0 deletions examples/configs/evals/llama_mmlu_eval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# MMLU evaluation for Llama 3.2 1B (base model).
# Override generation.model_name with checkpoint path at runtime:
#   generation.model_name=/path/to/checkpoint/policy/weights
defaults: "eval.yaml"

generation:
  model_name: "meta-llama/Llama-3.2-1B"
  vllm_cfg:
    max_model_len: 2048

tokenizer:
  # Follow whatever model is being evaluated (base or checkpoint override).
  name: ${generation.model_name}
  chat_template: null  # base model, no chat formatting

data:
  prompt_file: "examples/prompts/mmlu.txt"
  dataset_name: "mmlu"

env:
  math:
    verifier_type: "multilingual_multichoice"
Loading
Loading