
feat: decouple IS correction level from GSPO policy ratio #2269

Open
bo-ke wants to merge 1 commit into NVIDIA-NeMo:main from bo-ke:feat/decouple-is-correction-level

Conversation


@bo-ke bo-ke commented Apr 15, 2026

What does this PR do?

Decouple IS correction granularity from the GSPO policy ratio so they can be
configured independently.

Currently sequence_level_importance_ratios=True bundles two behaviors into a
single flag:

  1. GSPO policy ratio: exp(mean(log(π_curr/π_prev))) — numerically stable
  2. IS correction weight: exp(sum(prev - gen)) — collapses for long
    sequences

For long sequences (~9k+ tokens), even a small per-token logprob mismatch between
training and inference backends (e.g., ~2% with vLLM + Megatron) accumulates via
exp(sum) to near-zero IS weights, effectively reducing the learning rate by
10–50x with high variance across steps.
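The collapse is easy to reproduce with a back-of-the-envelope calculation. A minimal sketch in plain Python, assuming a hypothetical constant per-token logprob gap of -5e-4 (chosen only to match the reported scale, not a value measured in this PR):

```python
import math

seq_len = 9000   # ~9k response tokens, as in the report above
delta = -5e-4    # hypothetical per-token logprob gap (prev - gen)

# "sequence": exp(sum) over the whole response -- the gap compounds
w_sequence = math.exp(delta * seq_len)
# "sequence_mean": exp(mean) -- only one token's worth of gap survives
w_mean = math.exp(delta)

print(f"exp(sum)  = {w_sequence:.4f}")  # ~0.011, a ~90x effective LR drop
print(f"exp(mean) = {w_mean:.4f}")      # stays near 1.0
```

Even this tiny, systematic gap drives `exp(sum)` to ~0.01 at 9k tokens, while `exp(mean)` is indistinguishable from 1.0.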

This PR adds a new is_correction_level config field to ClippedPGLossConfig
that controls IS correction weight computation independently of the PPO ratio
(which remains controlled by sequence_level_importance_ratios).

Supported values:

  • "sequence"exp(sum(prev - gen)), one weight per sequence (current
    behavior)
  • "sequence_mean"exp(mean(prev - gen)), geometric mean, numerically stable
    for long sequences
  • "token"exp(prev - gen) per token, token-level correction

When unset (None), behavior is derived from sequence_level_importance_ratios
for full backward compatibility.
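As a sketch of what the three levels compute, here is a toy per-sequence version in plain Python (the helper name is made up for illustration; the actual implementation operates on masked tensors):

```python
import math

def is_correction_weights(prev_logprobs, gen_logprobs, level):
    """Toy per-sequence version of the three IS correction levels.

    prev_logprobs / gen_logprobs: per-token logprobs for one response,
    from the training and generation backends respectively.
    """
    deltas = [p - g for p, g in zip(prev_logprobs, gen_logprobs)]
    if level == "token":
        return [math.exp(d) for d in deltas]        # one weight per token
    if level == "sequence":
        return math.exp(sum(deltas))                # current behavior
    if level == "sequence_mean":
        return math.exp(sum(deltas) / len(deltas))  # geometric mean of token weights
    raise ValueError(f"unknown is_correction_level: {level!r}")
```

For a 2-token sequence with a uniform +0.1 gap, `"sequence"` gives exp(0.2) ≈ 1.22 while `"sequence_mean"` gives exp(0.1) ≈ 1.11; the gap between the two grows with sequence length.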

Issues

None — discovered during GRPO training with Qwen3.5-35B-A3B (MoE, ~9k avg
response tokens) where sampling_importance_ratio showed mean ~0.4, std ~0.27,
with some steps collapsing to 0.02.

Usage

loss_fn:
  # GSPO policy ratio (unchanged)                                                
  sequence_level_importance_ratios: true
  token_level_loss: false                                                        
  # IS correction now independently configurable                               
  is_correction_level: "sequence_mean"  # or "token", or "sequence" (default)    
  use_importance_sampling_correction: true                                       
  truncated_importance_sampling_ratio: 2                                         
  truncated_importance_sampling_type: tis                                        
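
The backward-compatibility fallback described above could be resolved along these lines (a sketch; the function name and dict-style config access are assumptions for illustration, not the PR's actual code):

```python
def resolve_is_correction_level(loss_cfg: dict) -> str:
    # Explicit setting wins; otherwise fall back to the pre-PR coupling,
    # where sequence_level_importance_ratios controlled both the policy
    # ratio and the IS correction granularity.
    level = loss_cfg.get("is_correction_level")
    if level is not None:
        return level
    if loss_cfg.get("sequence_level_importance_ratios"):
        return "sequence"  # old GSPO behavior: exp(sum(prev - gen))
    return "token"         # old token-level behavior
```

With the usage config above, an explicit `"sequence_mean"` overrides the fallback even though `sequence_level_importance_ratios` is true.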

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed /NVIDIA-NeMo/RL/blob/main/CONTRIBUTING.md
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our
    /NVIDIA-NeMo/RL/blob/main/docs/testing.md for how to run tests
  • Did you add or update any necessary documentation? Visit our
    /NVIDIA-NeMo/RL/blob/main/docs/documentation.md for how to write, build and test
    the docs.

Additional Information

  • All 44 existing tests pass (backward compat verified)
  • New test test_clipped_pg_loss_gspo_is_correction_level_sequence_mean validates
    that exp(mean) produces IS weights closer to 1.0 than exp(sum) and that the loss
    value matches hand-calculated expectations
  • seq-mask-tis compatibility assertion updated to account for the new field

@bo-ke bo-ke requested review from a team as code owners April 15, 2026 03:42

copy-pr-bot bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


Add `is_correction_level` config field to `ClippedPGLossConfig` that
controls IS correction weight computation independently of the PPO ratio
(which is still controlled by `sequence_level_importance_ratios`).

Currently `sequence_level_importance_ratios=True` bundles two behaviors:
1. GSPO policy ratio: exp(mean(log(π_curr/π_prev))) — stable
2. IS correction: exp(sum(prev - gen)) — collapses for long sequences

For long sequences (~9k+ tokens), even a small per-token logprob
mismatch between training and inference backends accumulates via exp(sum)
to near-zero IS weights, effectively reducing the learning rate by 10-50x
with high variance across steps.

The new `is_correction_level` field supports:
- "sequence": exp(sum) — current behavior
- "sequence_mean": exp(mean) — geometric mean, numerically stable
- "token": per-token exp(prev - gen)

When unset (None), behavior is derived from
`sequence_level_importance_ratios` for full backward compatibility.

Signed-off-by: kebo01 <kebo01@baidu.com>
@bo-ke bo-ke force-pushed the feat/decouple-is-correction-level branch from 9931c9e to c7b1b01 Compare April 15, 2026 03:46
