
Fix orthogonal parameterization save/load for training continuation#179

Open
aryamanarora wants to merge 1 commit into main from fix/orthogonal-save-load

Conversation

@aryamanarora
Collaborator

Summary

  • Fix orthogonal parameterization save/load to preserve internal state for proper training continuation
  • Add optional debug flag for metrics logging (off by default)

Problem

The orthogonal parameterization in LoreftIntervention uses PyTorch's torch.nn.utils.parametrizations.orthogonal, which stores an internal "original" tensor and computes the orthogonal weight via Cayley/Householder transform.
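For context, a minimal sketch (not pyreft code) of how PyTorch's orthogonal parametrization stores its state — the state dict holds the internal unconstrained tensor, while the orthogonal weight is recomputed on access:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

# Wrap a square linear layer; its `weight` is now computed on the fly
# from an internal unconstrained "original" tensor.
layer = orthogonal(nn.Linear(4, 4, bias=False))

# The state_dict exposes the internal tensor, not the computed weight:
keys = list(layer.state_dict().keys())
assert "parametrizations.weight.original" in keys

# Accessing .weight recomputes an (approximately) orthogonal matrix.
Q = layer.weight
assert torch.allclose(Q @ Q.T, torch.eye(4), atol=1e-5)
```

This is why saving only the computed weight is lossy: the optimizer steps on `original`, so restoring `weight` alone leaves `original` in a fresh, inconsistent state.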

Previously, state_dict() only saved the computed orthogonal weight, not the internal parametrization state. When loading and continuing training:

  1. The orthogonal weight was written to .base
  2. But .original was freshly initialized
  3. After the first optimizer step, orthogonality broke (error jumped from ~1e-7 to ~0.1)

This caused loss spikes when continuing training from checkpoints.

Solution

  • state_dict() now saves rotate_layer_original and rotate_layer_base
  • load_state_dict() restores full parametrization state when available
  • Backwards compatible: legacy checkpoints (without new keys) still work for inference
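A hedged sketch of what such overrides might look like. The key names `rotate_layer_original` and `rotate_layer_base` follow the PR description, but the surrounding `RotateModule` class is hypothetical, not the actual `LoreftIntervention`:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

class RotateModule(nn.Module):
    """Hypothetical module with an orthogonally parametrized layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.rotate_layer = orthogonal(nn.Linear(dim, dim, bias=False))

    def state_dict(self, *args, **kwargs):
        sd = super().state_dict(*args, **kwargs)
        # Also expose the parametrization internals under stable names
        # so training can resume without re-initializing them.
        p = self.rotate_layer.parametrizations["weight"]
        sd["rotate_layer_original"] = p.original.detach().clone()
        base = getattr(p[0], "base", None)  # buffer used by trivialization
        if base is not None:
            sd["rotate_layer_base"] = base.detach().clone()
        return sd

    def load_state_dict(self, state_dict, strict=True):
        state_dict = dict(state_dict)
        original = state_dict.pop("rotate_layer_original", None)
        base = state_dict.pop("rotate_layer_base", None)
        # strict=False keeps legacy checkpoints (without the new keys) loadable.
        result = super().load_state_dict(state_dict, strict=False)
        if original is not None:
            p = self.rotate_layer.parametrizations["weight"]
            with torch.no_grad():
                p.original.copy_(original)
                if base is not None and getattr(p[0], "base", None) is not None:
                    p[0].base.copy_(base)
        return result
```

Restoring `original` (and `base`, when present) means the first optimizer step after resuming operates on exactly the tensors it was stepping on before the checkpoint, so orthogonality is preserved.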

Test plan

  • Verified orthogonality is preserved through save/load/continue cycle
  • Verified legacy checkpoints still load correctly for inference
  • Tested with LoreftIntervention directly

🤖 Generated with Claude Code

@stanfordnlp stanfordnlp deleted a comment from chatgpt-codex-connector bot Jan 12, 2026
@aryamanarora
Collaborator Author

@codex review

@aryamanarora
Collaborator Author

looks like codex is not working. have to call my other agent.

@frankaging review

@frankaging
Member

LGTM!

The orthogonal parameterization in LoreftIntervention now correctly
preserves internal state during checkpoint save/load. Previously, only
the computed orthogonal weight was saved, which broke orthogonality
during training continuation (causing loss spikes).

Changes:
- state_dict now saves rotate_layer_original and rotate_layer_base
- load_state_dict restores full parametrization state when available
- Backwards compatible: legacy checkpoints still work for inference
- Add optional debug flag for metrics logging (off by default)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@aryamanarora force-pushed the fix/orthogonal-save-load branch from fe2cfc5 to d09c5bd on January 13, 2026 07:22
