[Pallas] Use LONG_INT_TYPE for jagged offsets in examples and tests #2132
Draft
norx1991 wants to merge 1 commit into
Extend the LONG_INT_TYPE pattern from #1950 (cross_entropy) to all jagged examples and their tests. Jagged offset tensors are now int32 on Pallas/TPU and int64 elsewhere. torch.cumsum on int32 silently promotes to int64, so the dtype= kwarg is also passed to cumsum to keep offsets in LONG_INT_TYPE. This gets the jagged tests past the int64 input rejection in Pallas (introduced in #1950); remaining xfails now hit their originally-documented JAX tracer / BlockSpec errors.
Summary
Follow-up to #1950 (which introduced `LONG_INT_TYPE` and applied it to `cross_entropy`). Extends the pattern to all jagged examples and their tests so offset tensors are int32 on Pallas/TPU and int64 elsewhere.

`torch.cumsum` on int32 silently promotes to int64, so `dtype=` is also passed to `cumsum` to keep offsets in `LONG_INT_TYPE`.

This gets the jagged tests past the int64 input rejection in Pallas; remaining xfails now hit their originally-documented JAX tracer / BlockSpec errors instead of the int64 rejection.
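
For reviewers who want to see the cumsum promotion in isolation, here is a minimal sketch. The `long_int_type` helper and the `on_pallas` flag are hypothetical stand-ins for illustration only; the real `LONG_INT_TYPE` definition from #1950 is not reproduced here, and the sketch simply assumes it resolves to int32 on Pallas/TPU.

```python
import torch


# Hypothetical stand-in for the project's LONG_INT_TYPE constant.
# Assumption: int32 on Pallas/TPU, int64 on every other backend.
def long_int_type(on_pallas: bool) -> torch.dtype:
    return torch.int32 if on_pallas else torch.int64


LONG_INT_TYPE = long_int_type(on_pallas=True)

# Per-row lengths of a jagged tensor with four rows.
lengths = torch.tensor([3, 0, 2, 4], dtype=LONG_INT_TYPE)

# Without dtype=, cumsum on an int32 input returns an int64 tensor.
promoted = torch.cumsum(lengths, dim=0)

# Passing dtype= keeps the running sum in LONG_INT_TYPE.
kept = torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE)

# Jagged offsets: a leading zero followed by the cumulative row lengths.
offsets = torch.cat([torch.zeros(1, dtype=LONG_INT_TYPE), kept])

print(promoted.dtype, kept.dtype, offsets.dtype)
```

Without the `dtype=` argument, the offsets built from `promoted` would be int64 even though `lengths` was created as int32, which is the case the Pallas int64 input rejection catches.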