[NNX] Implement pure-NNX path for post-training correctness tests#4186
Open
ecnal-cienet wants to merge 1 commit into
Open
[NNX] Implement pure-NNX path for post-training correctness tests#4186ecnal-cienet wants to merge 1 commit into
ecnal-cienet wants to merge 1 commit into
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
9704409 to
8be7bd5
Compare
The post-training correctness / golden-generation tests raised NotImplementedError under pure_nnx=True (the default since PR11). Wire up the NNX path as a dispatch that keeps the Linen path working, since Linen stays available for now: - NNX builds the model with model_creation_utils.from_pretrained (loads a checkpoint when load_parameters_path is set, otherwise inits) and a frozen reference via nnx.clone. Forward is the native model(...) call, per-token log-probs use compute_log_probs_nnx, and the GRPO loss uses grpo_loss_fn_nnx. - The Linen branches are unchanged (transformer_as_linen + state.params + compute_log_probs / grpo_loss_fn / _merge_grpo_state). Also export from_config from the maxtext package: grpo_trainer references mt.from_config, which previously did not resolve. Add a CPU unit test (tests/unit/correctness_tests_nnx_dispatch_test.py) covering the from_config export and the SFT dispatch on both paths, since the correctness tests themselves are TPU-only and skipped (b/425997645).
8be7bd5 to
7784694
Compare
|
🤖 Hi @ecnal-cienet, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
1 similar comment
|
🤖 I'm sorry @ecnal-cienet, but I was unable to process your request. Please see the logs for more details. |
NuojCheng
approved these changes
Jun 18, 2026
igorts-git
approved these changes
Jun 18, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Implement the pure-NNX path for the post-training correctness / golden-generation tests, which raised
NotImplementedErrorunderpure_nnx=True(the default since PR #3526). Done as a dispatch that keeps the Linen path working (Linen stays available for a while).NNX branch per site:
model_creation_utils.from_pretrained(config, mesh, rng_key)(loads a checkpoint whenload_parameters_pathis set, else inits); frozen reference viannx.clone.model(...); per-token log-probs =compute_log_probs_nnx; GRPO loss =grpo_loss_fn_nnx.transformer_as_linen+state.params+compute_log_probs/grpo_loss_fn/_merge_grpo_state).Also exports
from_configfrom themaxtextpackage —grpo_trainercallsmt.from_config, which previously did not resolve.Tests
CPU unit test —
tests/unit/correctness_tests_nnx_dispatch_test.pycovers the changed code that has no other CPU coverage: themt.from_configexport and the SFT correctness test's ownsetup_maxtext_model/get_maxtext_logitson both paths (pure_nnx=True/False). The GRPO NNX building blocks (compute_log_probs_nnx,grpo_loss_fn_nnx) are already covered bygrpo_nnx_test.py.JAX_PLATFORMS=cpu python3 -m pytest tests/unit/correctness_tests_nnx_dispatch_test.py -v # 3 passedFull tests (TPU host + HF auth + golden data): they're
@pytest.mark.skip(b/425997645) +@pytest.mark.tpu_only, so the skip must be lifted first.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.