xLSTMBlockStack.forward: remove **kwargs to enable torch.jit.script #119
Open
juergengp wants to merge 1 commit into
The prior `forward(self, x, **kwargs)` signature blocked
`torch.jit.script(xLSTMBlockStack(...))` because TorchScript requires
explicit (non-variadic) signatures on all forward methods. This was
surfaced by a downstream production use case (Eldric AI OS): C++
inference via libtorch needs a scriptable module to load the saved
artifact.
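As a minimal illustration of the constraint described above, the sketch below scripts two toy modules. `VariadicStack`, `ExplicitStack`, and `try_script` are hypothetical names made up for this example and are not part of xlstm; the only real API used is `torch.jit.script`.

```python
import torch
from torch import nn


class VariadicStack(nn.Module):
    """Hypothetical stand-in for a stack whose forward accepts **kwargs."""

    def __init__(self, dim: int = 8, depth: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x, **kwargs):
        for block in self.blocks:
            x = block(x)
        return x


class ExplicitStack(nn.Module):
    """Same toy stack, but with an explicit, non-variadic forward signature."""

    def __init__(self, dim: int = 8, depth: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


def try_script(module: nn.Module):
    """Return (scripted_module, None) on success or (None, exception) on failure."""
    try:
        return torch.jit.script(module), None
    except Exception as exc:  # TorchScript front-end errors derive from Exception
        return None, exc


scripted_variadic, variadic_error = try_script(VariadicStack())
scripted_explicit, explicit_error = try_script(ExplicitStack())

print("variadic forward scripted:", scripted_variadic is not None)
print("explicit forward scripted:", scripted_explicit is not None)
```

Only the signature differs between the two classes, which is the same shape of change this commit makes to the top-level forward.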
The kwargs were forwarded blindly through:

    xLSTMBlockStack.forward(x, **kw)
      -> xLSTMBlock.forward(x, **kw)
           -> xlstm.forward(x, **kw)   (mLSTM / sLSTM)
           -> ffn.forward(x, **kw)     (gated MLP)

Neither inner forward documents any non-empty kwargs argument; the
chain was an unused passthrough. Callers that need to thread state
through forward should use `step()` instead, which keeps its
explicit `state: dict` parameter and the per-block kwargs path
intact.
Trade-off: in the unlikely case anyone was passing kwargs into the
top-level forward via the dict-spread pattern, this is a breaking
change. None of the published xlstm examples / tutorials / NX-AI's
own training scripts do this — `model(x)` is the universal call shape.
Added tests/test_torchscript_compat.py with three regression checks:
1. torch.jit.script(model) doesn't raise
2. scripted forward output matches eager forward within fp32 epsilon
3. save -> load -> forward roundtrip is byte-stable
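The three checks could look roughly like the sketch below. This is not the content of tests/test_torchscript_compat.py; `TinyStack` and the `check_*` helpers are hypothetical stand-ins, the tolerance is an assumed value, and "byte-stable" is interpreted here as exact tensor equality after an in-memory save/load.

```python
import io

import torch
from torch import nn


class TinyStack(nn.Module):
    """Hypothetical scriptable stack (not the real xLSTMBlockStack)."""

    def __init__(self, dim: int = 8, depth: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


def check_matches_eager(model, scripted, x, atol: float = 1e-6) -> bool:
    """Check 2: scripted output agrees with eager output within a small tolerance."""
    with torch.no_grad():
        return torch.allclose(scripted(x), model(x), atol=atol)


def check_roundtrip(scripted, x) -> bool:
    """Check 3: save -> load -> forward reproduces the scripted output exactly."""
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)
    with torch.no_grad():
        return torch.equal(reloaded(x), scripted(x))


torch.manual_seed(0)
model = TinyStack().eval()
x = torch.randn(4, 8)

scripted = torch.jit.script(model)  # Check 1: scripting must not raise
matches_eager = check_matches_eager(model, scripted, x)
roundtrip_stable = check_roundtrip(scripted, x)
```

In a real test file each check would be its own test function built on the actual `xLSTMBlockStack` config rather than a toy module.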
Closes the hard blocker on TorchScript-based native C++ inference
for downstream users. JAX/XLA users (NX-AI's preferred path) are
unaffected — the JAX-side xlstm_jax repo doesn't use this PyTorch
forward.
I have read the CLA Document and I hereby sign the CLA.
You can retrigger this bot by commenting `recheck` in this Pull Request. Posted by the CLA Assistant Lite bot.
Member
Looks good, but then remove the `**kwargs` from all forward methods. Also, the test and comment can be removed. I don't see why the kwargs were there in the first place.
Summary
`xLSTMBlockStack.forward(self, x, **kwargs)` blocks `torch.jit.script(model)` because TorchScript requires explicit (non-variadic) signatures on all forward methods. This PR drops the `**kwargs` parameter from the top-level `forward` so the module can be scripted; the `step()` method (which has an explicit `state: dict` parameter) is unaffected.

Motivation
Downstream production use case (Eldric AI OS, 5.0 → 5.1 line): native C++ inference via libtorch needs the saved `.pt` artifact to come from a scriptable module. Without this PR the workaround is to monkey-patch the signature at export time; with the PR, users can call `torch.jit.script(xLSTMBlockStack(cfg))` directly.

We ran a reconnaissance probe (TorchScript path on Mac arm64; CPU fp32; `parallel_stabilized` mLSTM backend) before opening this PR: the script-compile fails on the variadic `**kwargs` first and on nothing else. Removing the `**kwargs` unblocks the whole TorchScript path. (Probe report on file with the maintainer if useful.)

What changed
- `xlstm/xlstm_block_stack.py`: `forward(self, x, **kwargs)` → `forward(self, x)`. Inner-block call switched from `block(x, **kwargs)` to `block(x)`.
- `tests/test_torchscript_compat.py` (new): three regression checks:
  1. `torch.jit.script(model)` doesn't raise
  2. scripted forward output matches eager forward within fp32 epsilon
  3. save -> load -> forward roundtrip is byte-stable

Trade-off
In the unlikely case anyone passes kwargs into the top-level `forward` via dict-spread, this is a breaking change. None of the published `xlstm` examples, tutorials, or NX-AI's own training scripts do this; `model(x)` is the universal call shape. Callers that thread state through forward should use `step()`, which keeps its explicit `state: dict` parameter and the per-block kwargs path intact.

If preserving the variadic signature is a hard requirement, an alternative is `model.scriptable_forward(x)` as a TorchScript-compatible sibling, leaving `forward(x, **kwargs)` untouched. Happy to iterate to whichever shape is preferred.

Tests
All three checks pass against this branch and fail against `main` (the script-compile step raises on `**kwargs`).

Compatibility
JAX users are unaffected: `xlstm_jax` doesn't use this PyTorch `forward`.