75 changes: 64 additions & 11 deletions pyreft/interventions.py
@@ -42,44 +42,97 @@ def __init__(self, **kwargs):
kwargs["dtype"] if "dtype" in kwargs else torch.bfloat16)
self.dropout = torch.nn.Dropout(kwargs["dropout"] if "dropout" in kwargs else 0.0)
self.act_fn = ACT2FN["linear"] if "act_fn" not in kwargs or kwargs["act_fn"] is None else ACT2FN[kwargs["act_fn"]]

# Debug logging (off by default)
self.debug = kwargs.get("debug", False)
self._debug_logged = False
self.metrics = {}
# Save full parametrization state for training continuation (off by default for smaller files)
self.save_for_training = kwargs.get("save_for_training", False)
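
        # A minimal usage sketch for the two new kwargs (the values below are
        # illustrative assumptions, not part of this diff):
        #
        #   intervention = LoreftIntervention(
        #       embed_dim=768,
        #       low_rank_dimension=8,
        #       debug=True,              # print norm metrics on the first forward
        #       save_for_training=True,  # checkpoint parametrization internals
        #   )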

def forward(
self, base, source=None, subspaces=None
):
rotated_base = self.rotate_layer(base)
-        output = base + torch.matmul(
-            (self.act_fn(self.learned_source(base)) - rotated_base), self.rotate_layer.weight.T
-        )
learned = self.act_fn(self.learned_source(base))
diff = learned - rotated_base
delta = torch.matmul(diff, self.rotate_layer.weight.T)

# Store metrics for logging (only if debug=True)
if self.debug:
# Note: ||delta|| = ||diff|| since R has orthonormal columns
diff_norm = diff.norm().item()
base_norm = base.norm().item()
b_norm = self.learned_source.bias.norm().item()
self.metrics = {
"base_norm": base_norm,
"rotated_base_norm": rotated_base.norm().item(),
"learned_norm": learned.norm().item(),
"b_norm": b_norm,
"diff_norm": diff_norm,
"delta_base_ratio": diff_norm / (base_norm + 1e-8),
}
if not self._debug_logged:
print(f"[DEBUG LoreftIntervention] First forward:")
print(f" base norm: {base_norm:.4f}, Rh norm: {self.metrics['rotated_base_norm']:.4f}")
print(f" Wh+b norm: {self.metrics['learned_norm']:.4f}, b norm: {b_norm:.4f}")
print(f" diff norm: {diff_norm:.4f}, delta/base ratio: {self.metrics['delta_base_ratio']:.4f}")
self._debug_logged = True

output = base + delta
return self.dropout(output.to(base.dtype))
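
    # Quick numerical check of the "||delta|| = ||diff||" note above (a
    # standalone sketch, not part of this diff): for R with orthonormal
    # columns, R.T @ R = I, so right-multiplication by R.T preserves norms.
    #
    #   R, _ = torch.linalg.qr(torch.randn(768, 8))  # orthonormal columns
    #   diff = torch.randn(4, 8)
    #   delta = torch.matmul(diff, R.T)
    #   assert torch.allclose(delta.norm(), diff.norm(), atol=1e-5)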

def state_dict(self, *args, **kwargs):
"""
-        Overwrite for data-efficiency.
Save state for checkpoint.

By default, only saves the computed orthogonal weight (for inference).
If save_for_training=True, also saves internal parametrization state
needed for training continuation without breaking orthogonality.
"""
state_dict = OrderedDict()
for k, v in self.learned_source.state_dict().items():
state_dict[k] = v
# Save computed orthogonal weight (always, for inference)
state_dict["rotate_layer"] = self.rotate_layer.weight.data
# Optionally save internal parametrization state (for training continuation)
if self.save_for_training:
state_dict["rotate_layer_original"] = self.rotate_layer.parametrizations.weight.original.data
state_dict["rotate_layer_base"] = self.rotate_layer.parametrizations.weight[0].base.data
return state_dict
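
    # Sketch of the resulting checkpoint keys (assuming learned_source is a
    # torch.nn.Linear, as constructed in this class):
    #   save_for_training=False -> {"weight", "bias", "rotate_layer"}
    #   save_for_training=True  -> the above plus {"rotate_layer_original",
    #                              "rotate_layer_base"}, trading a larger file
    #                              for exact training continuation.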

def load_state_dict(self, state_dict, *args, **kwargs):
"""
-        Overwrite for data-efficiency.
Load state from checkpoint. Properly restores parametrization internals
to support training continuation without breaking orthogonality.
"""
self.learned_source.load_state_dict(state_dict, strict=False)

        # Caveat: updating the parametrized layer in place may not work (root cause unclear),
        # so we recreate the layer and load the columns back.
overload_w = state_dict["rotate_layer"].to(
self.learned_source.weight.device)
overload_w_width = overload_w.shape[-1]

# Recreate the parametrized layer
rotate_layer = LowRankRotateLayer(
self.embed_dim, overload_w_width, init_orth=True).to(
self.learned_source.weight.device)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
-        self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w
-        assert torch.allclose(self.rotate_layer.weight.data, overload_w.data) == True # we must match!

# Check if we have the full parametrization state (new format)
if "rotate_layer_original" in state_dict and "rotate_layer_base" in state_dict:
# Restore internal parametrization state for training continuation
original = state_dict["rotate_layer_original"].to(self.learned_source.weight.device)
base = state_dict["rotate_layer_base"].to(self.learned_source.weight.device)
self.rotate_layer.parametrizations.weight.original.data.copy_(original)
self.rotate_layer.parametrizations.weight[0].base.data.copy_(base)
else:
# Legacy format: only has computed weight, use old behavior
# (works for inference but may break training continuation)
self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w

# Verify the loaded weight matches expected
assert torch.allclose(self.rotate_layer.weight.data, overload_w.data, atol=1e-5), \
f"Loaded weight mismatch: {torch.norm(self.rotate_layer.weight.data - overload_w.data)}"

return
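
    # Save/load round-trip sketch (hedged; constructor kwargs as in the sketch
    # near __init__, and "loreft.pt" is a hypothetical path):
    #
    #   trained = LoreftIntervention(embed_dim=768, low_rank_dimension=8,
    #                                save_for_training=True)
    #   torch.save(trained.state_dict(), "loreft.pt")
    #
    #   resumed = LoreftIntervention(embed_dim=768, low_rank_dimension=8)
    #   resumed.load_state_dict(torch.load("loreft.pt"))
    #   # rotate_layer.weight matches the saved weight (the assert above), and
    #   # the restored parametrization internals let training continue without
    #   # breaking orthogonality.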

