A jagged_gdpa example that works on Pallas TPU #2425

Open
AmesingFlank wants to merge 1 commit into AmesingFlank/stack/40 from AmesingFlank/stack/52



@AmesingFlank AmesingFlank commented May 14, 2026

Stacked PRs:


A jagged_gdpa example that works on Pallas TPU

Add a new jagged gdpa example that works on Pallas backends.

This kernel works on Pallas TPU because it does not depend on any
integer indexing. Instead, it uses tiled access on the jagged dimension,
which translates to a DMA that dynamically slices the third-to-last
dimension while loading the last two dimensions in full
(pltpu.make_async_copy(q[pl.ds(start, size), :, :])). This is the same
way RaggedPagedAttention handles jagged data.
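
To make the access pattern concrete, here is a minimal host-side sketch in NumPy. It is not the kernel from this PR and does not use Pallas. The names `iter_jagged_tiles` and `jagged_row_sums`, the `offsets` array layout, and the `[total_tokens, heads, dim]` shape are assumptions chosen for illustration. It only shows the idea of walking a jagged leading dimension in tiles described by `(start, size)` while taking the two trailing dimensions in full, with no per-element integer indexing.

```python
import numpy as np


def iter_jagged_tiles(offsets, tile):
    """Yield (segment_index, start, size) tiles covering each jagged segment.

    offsets[i]:offsets[i + 1] is assumed to delimit segment i along the
    leading (jagged) dimension.
    """
    for i in range(len(offsets) - 1):
        begin, end = int(offsets[i]), int(offsets[i + 1])
        for start in range(begin, end, tile):
            # The last tile of a segment may be shorter than `tile`.
            yield i, start, min(tile, end - start)


def jagged_row_sums(q, offsets, tile):
    """Sum each segment of q over the jagged dimension, one tile at a time.

    q is assumed to have shape [total_tokens, heads, dim].
    """
    out = np.zeros((len(offsets) - 1,) + q.shape[1:], dtype=q.dtype)
    for i, start, size in iter_jagged_tiles(offsets, tile):
        # Slice only the leading dimension; load the trailing two in full.
        # This mirrors the shape of q[pl.ds(start, size), :, :].
        block = q[start:start + size, :, :]
        out[i] += block.sum(axis=0)
    return out


# Two segments of length 2 and 4 packed into 6 rows.
q = np.arange(6 * 2 * 4, dtype=np.float32).reshape(6, 2, 4)
offsets = np.array([0, 2, 6])
tiles = list(iter_jagged_tiles(offsets, tile=3))
sums = jagged_row_sums(q, offsets, tile=3)
```

In this sketch the tail tile is handled by shrinking `size`. A real TPU kernel would more likely keep the tile size fixed and mask the rows past the end of the segment, but that detail depends on the kernel and is not shown here.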

AmesingFlank added a commit that referenced this pull request May 14, 2026
stack-info: PR: #2425, branch: AmesingFlank/stack/52
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from c1cef2e to 3a393ea Compare May 14, 2026 02:12
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 14, 2026
AmesingFlank added a commit that referenced this pull request May 14, 2026
@AmesingFlank AmesingFlank marked this pull request as draft May 14, 2026 02:18
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 14, 2026 02:18
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from 3a393ea to 6d37bde Compare May 14, 2026 02:18
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 14, 2026 02:18
@AmesingFlank AmesingFlank marked this pull request as ready for review May 14, 2026 02:18
AmesingFlank added a commit that referenced this pull request May 14, 2026
@AmesingFlank AmesingFlank marked this pull request as draft May 14, 2026 03:00
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 14, 2026 03:00
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from 6d37bde to b3208cb Compare May 14, 2026 03:01
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 14, 2026 03:01
@AmesingFlank AmesingFlank marked this pull request as ready for review May 14, 2026 03:01
AmesingFlank added a commit that referenced this pull request May 14, 2026
@AmesingFlank AmesingFlank marked this pull request as draft May 14, 2026 18:38
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 14, 2026 18:38
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from b3208cb to acc2bf4 Compare May 14, 2026 18:38
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 14, 2026 18:38
@AmesingFlank AmesingFlank marked this pull request as ready for review May 14, 2026 18:38
AmesingFlank added a commit that referenced this pull request May 14, 2026
@AmesingFlank AmesingFlank marked this pull request as draft May 14, 2026 18:40
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 14, 2026 18:40
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from acc2bf4 to 4debb73 Compare May 14, 2026 18:40
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 14, 2026 18:40
@AmesingFlank AmesingFlank marked this pull request as ready for review May 14, 2026 18:40
AmesingFlank added a commit that referenced this pull request May 15, 2026
@AmesingFlank AmesingFlank marked this pull request as draft May 15, 2026 01:29
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 15, 2026 01:29
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from 4debb73 to c7a07ed Compare May 15, 2026 01:29
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 15, 2026 01:29
@AmesingFlank AmesingFlank marked this pull request as ready for review May 15, 2026 01:30
@AmesingFlank AmesingFlank marked this pull request as draft May 15, 2026 20:35
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 15, 2026 20:35
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/52 branch from c7a07ed to cbd48b0 Compare May 15, 2026 20:36
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 15, 2026 20:36
@AmesingFlank AmesingFlank marked this pull request as ready for review May 15, 2026 20:36
@AmesingFlank AmesingFlank marked this pull request as draft May 15, 2026 20:50
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 15, 2026 20:51
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 15, 2026 20:51
@AmesingFlank AmesingFlank marked this pull request as ready for review May 15, 2026 20:51
@AmesingFlank AmesingFlank marked this pull request as draft May 15, 2026 21:16
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/40 to main May 15, 2026 21:16
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/40 May 15, 2026 21:17
@AmesingFlank AmesingFlank marked this pull request as ready for review May 15, 2026 21:17