fix(networks): replace Tensor | None with Optional[Tensor] for TorchScript compatibility#8879
Conversation
…r] for TorchScript compatibility The `|` union type syntax (e.g. `torch.Tensor | None`) was introduced in Python 3.10. While `from __future__ import annotations` defers evaluation at runtime, TorchScript's annotation parser does not support this syntax and fails when scripting models that contain these forward method signatures. Replace `torch.Tensor | None` with `Optional[torch.Tensor]` in the `forward` methods of: - `monai/networks/blocks/crossattention.py` (CrossAttentionBlock) - `monai/networks/blocks/selfattention.py` (SABlock) - `monai/networks/blocks/transformerblock.py` (TransformerBlock) These three blocks are used in the ViT/UNETR scripting path, causing `RuntimeError: Can't redefine method: forward on class` when `torch.jit.script()` is called on a UNETR model. Closes Project-MONAI#7939 Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
|
Hey @holgerroth @wyli @ericspod. Could you, please, have a look at this? |
for more information, see https://pre-commit.ci
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
Hi @AlexanderSanin I don't see files changed anymore since pre-commit (using ruff I think) autofixed your changes back to what they were. You may have to create a variable for your type with However, we don't support Python 3.9 anymore so we expect the |
Summary
Fixes #7939
The
|union type syntax (e.g.torch.Tensor | None) was introduced in Python 3.10. Whilefrom __future__ import annotationsdefers annotation evaluation at runtime, TorchScript's annotation parser does not support this syntax and fails when scripting models that include these forward method signatures, producing:This PR replaces
torch.Tensor | NonewithOptional[torch.Tensor](fromtyping) in theforwardmethods of the three blocks that form the ViT/UNETR scripting path:monai/networks/blocks/crossattention.py—CrossAttentionBlock.forwardmonai/networks/blocks/selfattention.py—SABlock.forwardmonai/networks/blocks/transformerblock.py—TransformerBlock.forwardTest plan
torch.jit.script(CrossAttentionBlock(...))succeeds after fixtorch.jit.script(SABlock(...))succeeds after fixtorch.jit.script(TransformerBlock(...))succeeds after fixpython -m pytest tests/networks/nets/test_unetr.py -k test_scriptpython -m pytest tests/networks/blocks/test_crossattention.pypython -m pytest tests/networks/blocks/test_selfattention.py