xLSTMBlockStack.forward: remove **kwargs to enable torch.jit.script #119

Open
juergengp wants to merge 1 commit into NX-AI:main from juergengp:torchscript-compat-kwargs-removal

@juergengp

Summary

xLSTMBlockStack.forward(self, x, **kwargs) blocks torch.jit.script(model) because TorchScript requires explicit (non-variadic) signatures on all forward methods. This PR drops the **kwargs parameter from the top-level forward so the module can be scripted; the step() method (which has an explicit state: dict parameter) is unaffected.
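
A minimal sketch of the failure mode described above, using toy modules rather than the real xLSTMBlockStack. The class names `VariadicStack` and `ExplicitStack` and the helper `can_script` are hypothetical and exist only for illustration; the only thing the sketch relies on is that `torch.jit.script` rejects a variadic `**kwargs` parameter on `forward`.

```python
import torch
from torch import nn


class VariadicStack(nn.Module):
    """Toy stand-in for the old signature: forward(self, x, **kwargs)."""

    def __init__(self, dim: int = 8, num_blocks: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])

    def forward(self, x, **kwargs):
        # The kwargs are accepted but never used, mirroring an unused passthrough.
        for block in self.blocks:
            x = block(x)
        return x


class ExplicitStack(nn.Module):
    """Toy stand-in for the new signature: forward(self, x)."""

    def __init__(self, dim: int = 8, num_blocks: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


def can_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script compiles the module without raising."""
    try:
        torch.jit.script(module)
    except Exception:
        return False
    return True


variadic_ok = can_script(VariadicStack())
explicit_ok = can_script(ExplicitStack())
print(f"variadic forward scriptable: {variadic_ok}")
print(f"explicit forward scriptable: {explicit_ok}")
```

The two toy classes differ only in the `forward` signature, which is the same shape as the change proposed here.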

Motivation

Downstream production use case (Eldric AI OS, 5.0 → 5.1 line): native C++ inference via libtorch needs the saved .pt artifact to come from a scriptable module. Without this PR the workaround is to monkey-patch the signature at export time; with the PR, users can call torch.jit.script(xLSTMBlockStack(cfg)) directly.

We ran a reconnaissance probe (TorchScript path on Mac arm64; CPU fp32; parallel_stabilized mLSTM backend) before opening this PR — the script-compile fails on the variadic **kwargs first and on nothing else. Removing the **kwargs unblocks the whole TorchScript path. (Probe report on file with the maintainer if useful.)

What changed

  • xlstm/xlstm_block_stack.py: forward(self, x, **kwargs) → forward(self, x). Inner-block call switched from block(x, **kwargs) to block(x).
  • tests/test_torchscript_compat.py (new): three regression checks:
    1. torch.jit.script(model) doesn't raise
    2. scripted forward output matches eager forward within fp32 epsilon
    3. save → load → forward roundtrip is byte-stable
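
The three checks can be sketched roughly as follows. This is not the content of tests/test_torchscript_compat.py: `ToyStack` is a hypothetical stand-in for xLSTMBlockStack, the tolerance is an assumed value, and "byte-stable" is interpreted here as "the reloaded module reproduces the scripted output exactly".

```python
import io

import torch
from torch import nn


class ToyStack(nn.Module):
    """Hypothetical stand-in for xLSTMBlockStack with the explicit signature."""

    def __init__(self, dim: int = 8, num_blocks: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


torch.manual_seed(0)
model = ToyStack().eval()
x = torch.randn(4, 8)

# Check 1: scripting does not raise.
scripted = torch.jit.script(model)

# Check 2: scripted output matches eager output within a small fp32 tolerance
# (atol is an assumed value, not taken from the PR).
with torch.no_grad():
    eager_out = model(x)
    scripted_out = scripted(x)
matches_eager = torch.allclose(eager_out, scripted_out, atol=1e-6)

# Check 3: save -> load -> forward reproduces the scripted output exactly.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
with torch.no_grad():
    loaded_out = loaded(x)
roundtrip_exact = torch.equal(scripted_out, loaded_out)

print(f"matches eager: {matches_eager}, roundtrip exact: {roundtrip_exact}")
```

An in-memory buffer is used instead of a .pt file so the sketch leaves nothing on disk; `torch.jit.save` also accepts a file path.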

Trade-off

In the unlikely case anyone passes kwargs into the top-level forward via dict-spread, this is a breaking change. None of the published xlstm examples / tutorials / NX-AI's own training scripts do this — model(x) is the universal call shape. Callers that thread state through forward should use step(), which keeps its explicit state: dict parameter and the per-block kwargs path intact.
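
To make the breaking case concrete, here is a small sketch with a toy module (`ToyStack` and the `state` keyword are hypothetical, chosen only for illustration). After the signature change, the plain `model(x)` call and an empty dict-spread keep working, while any non-empty keyword argument raises a `TypeError`.

```python
import torch
from torch import nn


class ToyStack(nn.Module):
    """Hypothetical stand-in for a stack with the explicit forward(self, x)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = ToyStack()
x = torch.randn(2, 8)

# The universal call shape is unaffected.
y = model(x)

# Spreading an empty dict passes no keyword arguments, so it still works.
y_empty = model(x, **{})

# A non-empty keyword argument is no longer accepted by forward.
try:
    model(x, state={"hypothetical": 1})
    raised = False
except TypeError:
    raised = True

print(f"non-empty kwargs raise TypeError: {raised}")
```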

If preserving the variadic signature is a hard requirement, an alternative is model.scriptable_forward(x) as a TorchScript-compatible sibling, leaving forward(x, **kwargs) untouched. Happy to iterate to whichever shape is preferred.

Tests

$ pytest tests/test_torchscript_compat.py -v

All three checks pass against this branch and fail against main (the script-compile step raises on **kwargs).

Compatibility

JAX users are unaffected — xlstm_jax doesn't use this PyTorch forward.

The prior `forward(self, x, **kwargs)` signature blocked
`torch.jit.script(xLSTMBlockStack(...))` because TorchScript requires
explicit (non-variadic) signatures on all forward methods. This was
surfaced by a downstream production use case (Eldric AI OS): C++
inference via libtorch needs a scriptable module to load the saved
artifact.

The kwargs were forwarded blindly through:
  xLSTMBlockStack.forward(x, **kw)
    -> xLSTMBlock.forward(x, **kw)
        -> xlstm.forward(x, **kw)     (mLSTM / sLSTM)
        -> ffn.forward(x, **kw)       (gated MLP)

Neither inner forward documents any non-empty kwargs argument; the
chain was an unused passthrough. Callers that need to thread state
through forward should use `step()` instead, which keeps its
explicit `state: dict` parameter and the per-block kwargs path
intact.

Trade-off: in the unlikely case anyone was passing kwargs into the
top-level forward via the dict-spread pattern, this is a breaking
change. None of the published xlstm examples / tutorials / NX-AI's
own training scripts do this — `model(x)` is the universal call shape.

Added tests/test_torchscript_compat.py with three regression checks:
  1. torch.jit.script(model) doesn't raise
  2. scripted forward output matches eager forward within fp32 epsilon
  3. save -> load -> forward roundtrip is byte-stable

Closes the hard blocker on TorchScript-based native C++ inference
for downstream users. JAX/XLA users (NX-AI's preferred path) are
unaffected — the JAX-side xlstm_jax repo doesn't use this PyTorch
forward.
@github-actions


Thank you for your submission, we really appreciate it. Like many open-source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution. You can sign the CLA by posting a pull request comment in the format below.


I have read the CLA Document and I hereby sign the CLA


You can retrigger this bot by commenting recheck in this Pull Request. Posted by the CLA Assistant Lite bot.

@martinloretzzz
Member

Looks good, but then remove the **kwargs from all forward methods. The test and the comment can also be removed; I don't see why the kwargs were there in the first place.
