A jagged_gdpa example that works on Pallas TPU #2425
Open
AmesingFlank wants to merge 1 commit into
AmesingFlank
added a commit
that referenced
this pull request
May 14, 2026
Add a new jagged gdpa example that works on Pallas backends
This kernel works on Pallas TPU because it does not depend on any integer indexing. Instead, it uses tiled access on the jagged dimension, which translates to a DMA that dynamically slices the third-to-last (-3) dimension while loading the last two (-2 and -1) dimensions in full: pltpu.make_async_copy(q[pl.ds(start, size), :, :]). This is the same way RaggedPagedAttention handles jagged data.
stack-info: PR: #2425, branch: AmesingFlank/stack/52
jansel
approved these changes
May 14, 2026
Stacked PRs:
A jagged_gdpa example that works on Pallas TPU
Add a new jagged gdpa example that works on Pallas backends
This kernel works on Pallas TPU because it does not depend on any
integer indexing. Instead, it uses tiled access on the jagged
dimension, which translates to a DMA that dynamically slices the
third-to-last (-3) dimension while loading the last two (-2 and -1)
dimensions in full: pltpu.make_async_copy(q[pl.ds(start, size), :, :]).
This is the same way RaggedPagedAttention handles jagged data.
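
To make the access pattern described above concrete, here is a minimal plain-Python sketch of tiled access over a jagged dimension. It is not the kernel from this PR and does not use Pallas; the names `jagged_tiles` and `load_tile`, the `offsets` layout (cumulative sequence boundaries), and the zero-padding of the tail tile are all assumptions made for illustration. The point it shows is that every load is a contiguous `(start, size)` slice of the leading dimension, analogous to `pl.ds(start, size)`, with the two trailing dimensions taken in full and no per-element integer indices.

```python
from typing import Iterator, List, Tuple

Row = List[List[float]]  # one token: [heads][head_dim]


def jagged_tiles(offsets: List[int], tile: int) -> Iterator[Tuple[int, int, int]]:
    """Yield (seq_id, start, valid) for each tile of each sequence.

    `offsets` holds cumulative sequence boundaries (hypothetical layout):
    sequence i occupies rows offsets[i] .. offsets[i + 1] - 1.
    """
    for seq_id in range(len(offsets) - 1):
        seq_start, seq_end = offsets[seq_id], offsets[seq_id + 1]
        # Empty sequences produce no tiles because the range is empty.
        for start in range(seq_start, seq_end, tile):
            yield seq_id, start, min(tile, seq_end - start)


def load_tile(q: List[Row], start: int, valid: int, tile: int) -> Tuple[List[Row], List[bool]]:
    """Copy one contiguous slice of the leading (jagged) dimension.

    The slice q[start:start + valid] brings every row in whole, so the
    heads and head_dim dimensions are loaded in full. The tail is padded
    with zero rows up to `tile` and described by a boolean mask
    (an assumption of this sketch, not taken from the PR).
    """
    heads, dim = len(q[0]), len(q[0][0])
    rows = [[list(h) for h in row] for row in q[start:start + valid]]
    rows += [[[0.0] * dim for _ in range(heads)] for _ in range(tile - valid)]
    mask = [True] * valid + [False] * (tile - valid)
    return rows, mask


if __name__ == "__main__":
    offsets = [0, 3, 3, 8]  # three sequences of lengths 3, 0, and 5
    q = [[[float(100 * t + 10 * h + d) for d in range(3)] for h in range(2)]
         for t in range(8)]
    for seq_id, start, valid in jagged_tiles(offsets, tile=2):
        rows, mask = load_tile(q, start, valid, tile=2)
        print(seq_id, start, mask)
```

In a real Pallas TPU kernel the slice would presumably be issued as an asynchronous copy with a dynamic start; the sketch only mirrors the indexing scheme, in which the start varies at run time while the tile shape stays fixed.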