[NNX] Implement pure-NNX path for post-training correctness tests #4186

Open
ecnal-cienet wants to merge 1 commit into main from feat/nnx-correctness-tests


@ecnal-cienet ecnal-cienet commented Jun 17, 2026

Collaborator

Description

Implement the pure-NNX path for the post-training correctness / golden-generation tests, which previously raised NotImplementedError under pure_nnx=True (the default since PR #3526). The change is a dispatch on pure_nnx that keeps the Linen path working, since Linen stays available for now.

NNX branch per site:

  • Model built with model_creation_utils.from_pretrained(config, mesh, rng_key) (loads a checkpoint when load_parameters_path is set, else inits); frozen reference via nnx.clone.
  • Forward = native model(...); per-token log-probs = compute_log_probs_nnx; GRPO loss = grpo_loss_fn_nnx.
  • Linen branches unchanged (transformer_as_linen + state.params + compute_log_probs / grpo_loss_fn / _merge_grpo_state).
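
The per-site dispatch described above could look roughly like the following. This is a minimal, framework-free sketch: `DemoConfig`, `build_nnx`, `build_linen`, and `setup_model` are hypothetical stand-ins, not the functions touched by this PR. Only the `pure_nnx` flag comes from the description.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class DemoConfig:
  """Hypothetical stand-in for the real config; only `pure_nnx` mirrors the PR."""
  pure_nnx: bool = True


def build_nnx(config: DemoConfig) -> Tuple[str, Callable[[Any], Any]]:
  # Stand-in for the NNX branch: the model object itself is the callable.
  model = lambda tokens: [t * 2 for t in tokens]
  return "nnx", model


def build_linen(config: DemoConfig) -> Tuple[str, Callable[[Any], Any]]:
  # Stand-in for the Linen branch: parameters live outside the module.
  params = {"scale": 2}
  apply_fn = lambda tokens: [t * params["scale"] for t in tokens]
  return "linen", apply_fn


def setup_model(config: DemoConfig) -> Tuple[str, Callable[[Any], Any]]:
  """Dispatch on `pure_nnx` instead of raising NotImplementedError."""
  if config.pure_nnx:
    return build_nnx(config)
  return build_linen(config)
```

Keeping both builders behind one entry point means call sites do not need to know which path is active, which is what lets the Linen branch stay untouched.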

Also exports from_config from the maxtext package — grpo_trainer calls mt.from_config, which previously did not resolve.

Tests

CPU unit test: tests/unit/correctness_tests_nnx_dispatch_test.py covers the changed code that has no other CPU coverage, namely the mt.from_config export and the SFT correctness test's own setup_maxtext_model/get_maxtext_logits on both paths (pure_nnx=True/False). The GRPO NNX building blocks (compute_log_probs_nnx, grpo_loss_fn_nnx) are already covered by grpo_nnx_test.py.

JAX_PLATFORMS=cpu python3 -m pytest tests/unit/correctness_tests_nnx_dispatch_test.py -v
# 3 passed
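
A both-paths check of the kind described above might be structured as follows. This is a sketch only: `get_logits` and `DispatchTest` are hypothetical names, and the stand-in bodies do not reflect the contents of correctness_tests_nnx_dispatch_test.py. Whether the real test compares the two paths against each other is an assumption.

```python
import unittest


def get_logits(pure_nnx: bool, tokens):
  """Hypothetical stand-in for a forward pass dispatched on `pure_nnx`."""
  if pure_nnx:
    return [float(t) for t in tokens]  # stand-in for the NNX branch
  return [float(t) for t in tokens]  # stand-in for the Linen branch


class DispatchTest(unittest.TestCase):

  def test_both_paths(self):
    tokens = [1, 2, 3]
    results = {}
    # Exercise each path in its own subTest so a failure names the flag value.
    for pure_nnx in (True, False):
      with self.subTest(pure_nnx=pure_nnx):
        results[pure_nnx] = get_logits(pure_nnx, tokens)
        self.assertEqual(len(results[pure_nnx]), len(tokens))
    # Assumed expectation: both paths produce the same output for the same input.
    self.assertEqual(results[True], results[False])
```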

Full tests (TPU host + HF auth + golden data): these are marked @pytest.mark.skip (b/425997645) and @pytest.mark.tpu_only, so the skip must be lifted first.

export PYTHONPATH=$PWD/src HF_TOKEN=<token>

# SFT — comment the @pytest.mark.skip line, then run on TPU:
sed -i '/@pytest.mark.skip(reason="Logit output test fragile/s/^/# /' \
  tests/post_training/integration/sft_trainer_correctness_test.py
python3 -m pytest -v -s \
  tests/post_training/integration/sft_trainer_correctness_test.py::SFTTrainerCorrectnessTest::test_sft_trainer_correctness
git checkout tests/post_training/integration/sft_trainer_correctness_test.py   # restore the skip

# GRPO trainer correctness — same skip-lift pattern:
sed -i '/@pytest.mark.skip(reason="Logit output test fragile/s/^/# /' \
  tests/post_training/integration/grpo_trainer_correctness_test.py
python3 -m pytest -v -s \
  tests/post_training/integration/grpo_trainer_correctness_test.py::GRPOTrainerCorrectnessTest::test_grpo_trainer_correctness
git checkout tests/post_training/integration/grpo_trainer_correctness_test.py

# GRPO correctness — torch/TRL, no method skip (use the post-train venv):
python3 -m pytest -v -s tests/post_training/integration/grpo_correctness.py

# Golden generator (script, not a pytest test):
python3 -m tests.assets.logits_generation.generate_grpo_golden_logits

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 17, 2026


Codecov Report

✅ All modified and coverable lines are covered by tests.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-correctness-tests branch from 9704409 to 8be7bd5 Compare June 17, 2026 17:25
The post-training correctness / golden-generation tests raised
NotImplementedError under pure_nnx=True (the default since PR #3526). Wire
up the NNX path as a dispatch that keeps the Linen path working, since
Linen stays available for now:

- NNX builds the model with model_creation_utils.from_pretrained (loads a
  checkpoint when load_parameters_path is set, otherwise inits) and a
  frozen reference via nnx.clone. Forward is the native model(...) call,
  per-token log-probs use compute_log_probs_nnx, and the GRPO loss uses
  grpo_loss_fn_nnx.
- The Linen branches are unchanged (transformer_as_linen + state.params +
  compute_log_probs / grpo_loss_fn / _merge_grpo_state).

Also export from_config from the maxtext package: grpo_trainer references
mt.from_config, which previously did not resolve.

Add a CPU unit test (tests/unit/correctness_tests_nnx_dispatch_test.py)
covering the from_config export and the SFT dispatch on both paths, since
the correctness tests themselves are TPU-only and skipped (b/425997645).
@github-actions


🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions


🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details.

