Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 213 additions & 11 deletions examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@
from __future__ import annotations

"""
DeepSeek V3.2-EXP single-layer decode FRONT part — Scope 1 only (batch=16).
DeepSeek V3.2-EXP single-layer decode front path (batch=16).

Scope 1: input RMSNorm + Q/KV projection.
Projection stage:
- Compute RMSNorm of hidden_states
- Project to Q latent (qr) via wq_a
- Apply q_norm to Q latent, then project to Q heads (q_proj) via wq_b
- Apply q_norm to Q latent, then project to Q heads via wq_b
- Project to KV latent (kv_a) via wkv_a

Decode cache preparation:
- Split Q heads into q_nope and q_pe outputs
- Apply RoPE to q_pe and k_pe
- Apply kv_norm to KV latent, then update KV/PE cache for the current token

Aligned to official v3.2-exp MLA shapes:
- qk_nope_head_dim = 128
- qk_rope_head_dim = 64
Expand All @@ -28,6 +33,7 @@


BATCH = 16
MAX_SEQ = 4096
HIDDEN = 7168
NUM_HEADS = 128
Q_LORA_RANK = 1536
Expand All @@ -37,11 +43,12 @@
QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM
V_HEAD_DIM = 128
KV_A_OUT = KV_LORA_RANK + QK_ROPE_HEAD_DIM
CACHE_ROWS = BATCH * MAX_SEQ

EPS = 1e-6
HIDDEN_INV = 1.0 / HIDDEN

# Tile sizes tuned for standalone scope-1 incore boundaries:
# Tile sizes tuned for the standalone front projection/cache-prep example:
# - PROJ_K = K-dimension chunk for projection matmuls (kept at 512).
# - LORA_CHUNK, KV_OUT_CHUNK = 64 so AIC Right buffer ≤ 65536
# (512 * 64 * 2 = 65536).
Expand All @@ -52,7 +59,7 @@
# static_assert in pto_tile.hpp. The original 3-scope pipeline used
# BATCH_TILE=4 because the combined scopes allowed a different split.
# - LOCAL_PAD_WIDTH removed; the pad tensor was a tuning hint for the
# combined scope1+2+3 pipeline and is not needed for scope1 alone.
# combined front pipeline and is not needed here.
RMSNORM_K = 512
PROJ_K = 512
Q_OUT_CHUNK = 64
Expand All @@ -63,6 +70,7 @@

def build_deepseek_v3_2_decode_front_scope1_program(
batch: int = BATCH,
max_seq_len: int = MAX_SEQ,
hidden_size: int = HIDDEN,
num_heads: int = NUM_HEADS,
q_lora_rank: int = Q_LORA_RANK,
Expand All @@ -72,6 +80,7 @@ def build_deepseek_v3_2_decode_front_scope1_program(
v_head_dim: int = V_HEAD_DIM,
):
BATCH_CFG = batch
MAX_SEQ_CFG = max_seq_len
HIDDEN_CFG = hidden_size
NUM_HEADS_CFG = num_heads
Q_LORA_RANK_CFG = q_lora_rank
Expand All @@ -81,6 +90,7 @@ def build_deepseek_v3_2_decode_front_scope1_program(
QK_HEAD_DIM_CFG = qk_nope_head_dim + qk_rope_head_dim
V_HEAD_DIM_CFG = v_head_dim
KV_A_OUT_CFG = kv_lora_rank + qk_rope_head_dim
CACHE_ROWS_CFG = batch * max_seq_len
Q_LORA_INV_CFG = 1.0 / q_lora_rank

RMSNORM_BLOCKS = (HIDDEN_CFG + RMSNORM_K - 1) // RMSNORM_K
Expand All @@ -100,12 +110,21 @@ def deepseek_v3_2_decode_front_scope1(
q_norm_weight: pl.Tensor[[1, Q_LORA_RANK_CFG], pl.FP32],
wq_b: pl.Tensor[[Q_LORA_RANK_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16],
wkv_a: pl.Tensor[[HIDDEN_CFG, KV_A_OUT_CFG], pl.BF16],
seq_lens: pl.Tensor[[BATCH_CFG], pl.INT32],
rope_cos: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32],
rope_sin: pl.Tensor[[MAX_SEQ_CFG, QK_ROPE_HEAD_DIM_CFG], pl.FP32],
kv_norm_weight: pl.Tensor[[1, KV_LORA_RANK_CFG], pl.FP32],
# Output buffers
qr_out: pl.Tensor[[BATCH_CFG, Q_LORA_RANK_CFG], pl.BF16],
q_proj_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16],
q_nope_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_NOPE_HEAD_DIM_CFG], pl.BF16],
q_pe_out: pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_ROPE_HEAD_DIM_CFG], pl.BF16],
kv_a_out: pl.Tensor[[BATCH_CFG, KV_A_OUT_CFG], pl.BF16],
) -> pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], pl.BF16]:
# Scope 1: input RMSNorm + Q/KV projection.
kv_cache: pl.Tensor[[CACHE_ROWS_CFG, KV_LORA_RANK_CFG], pl.BF16],
pe_cache: pl.Tensor[[CACHE_ROWS_CFG, QK_ROPE_HEAD_DIM_CFG], pl.BF16],
) -> pl.Tensor[[BATCH_CFG, NUM_HEADS_CFG * QK_ROPE_HEAD_DIM_CFG], pl.BF16]:
q_proj_out = pl.create_tensor([BATCH_CFG, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16)

# Front projection: input RMSNorm + Q/KV projections.
for b0 in pl.range(0, BATCH_CFG, BATCH_TILE):
normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.BF16)
qr_fp32_tile = pl.create_tensor([BATCH_TILE, Q_LORA_RANK_CFG], dtype=pl.FP32)
Expand Down Expand Up @@ -228,7 +247,121 @@ def deepseek_v3_2_decode_front_scope1(
)
kv_a_out = pl.assemble(kv_a_out, kv_chunk, [b0, kv0])

return q_proj_out
kv_normed_out = pl.create_tensor([BATCH_CFG, KV_LORA_RANK_CFG], dtype=pl.BF16)

with pl.at(level=pl.Level.CORE_GROUP):
kv_rows = pl.cast(pl.slice(kv_a_out, [BATCH_CFG, KV_LORA_RANK_CFG], [0, 0]), target_type=pl.FP32)
kv_partial_sq = pl.reshape(pl.row_sum(pl.mul(kv_rows, kv_rows)), [1, BATCH_CFG])
kv_variance = pl.reshape(
pl.add(pl.mul(kv_partial_sq, 1.0 / KV_LORA_RANK_CFG), EPS),
[BATCH_CFG, 1],
)
kv_inv_rms = pl.recip(pl.sqrt(kv_variance))
kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK_CFG], [0, 0])
kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_rows, kv_inv_rms), kv_gamma)
kv_normed_out = pl.assemble(kv_normed_out, pl.cast(kv_normed, target_type=pl.BF16), [0, 0])

# Q split + q_rope: produce q_nope/q_pe while keeping the internal
# full Q layout available for the in-place RoPE writes.
# Cache preparation: write RMS-normalized KV latent and rotated
# k_pe for the current decode token.
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for b in pl.parallel(0, BATCH_CFG, 1, chunk=4):
ctx_len = pl.tensor.read(seq_lens, [b])
pos = ctx_len - 1
cache_row = b * MAX_SEQ_CFG + pos

cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
cos_hi = pl.slice(
rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
)
sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
sin_hi = pl.slice(
rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
)

for h in pl.range(NUM_HEADS_CFG):
q_col = h * QK_HEAD_DIM_CFG
q_nope_col = h * QK_NOPE_HEAD_DIM_CFG
q_pe_col = h * QK_ROPE_HEAD_DIM_CFG
q_nope = pl.slice(q_proj_out, [1, QK_NOPE_HEAD_DIM_CFG], [b, q_col])
q_nope_out = pl.assemble(q_nope_out, q_nope, [b, q_nope_col])
q_lo = pl.cast(
pl.slice(
q_proj_out,
[1, QK_ROPE_HEAD_DIM_CFG // 2],
[b, q_col + QK_NOPE_HEAD_DIM_CFG],
),
target_type=pl.FP32,
)
q_hi = pl.cast(
pl.slice(
q_proj_out,
[1, QK_ROPE_HEAD_DIM_CFG // 2],
[b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
),
target_type=pl.FP32,
)
q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
q_proj_out = pl.assemble(
q_proj_out,
pl.cast(q_rot_lo, target_type=pl.BF16),
[b, q_col + QK_NOPE_HEAD_DIM_CFG],
)
q_proj_out = pl.assemble(
q_proj_out,
pl.cast(q_rot_hi, target_type=pl.BF16),
[b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
)
Comment on lines +307 to +316
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pl.assemble calls updating q_proj_out with rotated values are redundant. q_proj_out is a local temporary tensor created at line 125 and is not used after this loop (the function returns q_pe_out). Removing these unnecessary Global Memory writes will improve performance.

                        q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
                        q_pe_out = pl.assemble(
                            q_pe_out,
                            pl.cast(q_rot_hi, target_type=pl.BF16),
                            [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
                        )

Comment on lines +307 to +316
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Dead writes: rotated q_pe values stored back into local q_proj_out are never read.

q_proj_out is a local tensor (line 125) and is not consumed after this loop — return is q_pe_out, and the second parallel loop (lines 324-362) doesn't read from q_proj_out. The in-place RoPE write-back into the PE slice of q_proj_out is therefore wasted work inside a tight per-head loop. The comment at 264-265 about "keeping the internal full Q layout available for the in-place RoPE writes" no longer reflects the actual data flow — q_pe_out is assembled directly from q_rot_lo/q_rot_hi two lines below.

🧹 Proposed cleanup
                         q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
                         q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
-                        q_proj_out = pl.assemble(
-                            q_proj_out,
-                            pl.cast(q_rot_lo, target_type=pl.BF16),
-                            [b, q_col + QK_NOPE_HEAD_DIM_CFG],
-                        )
-                        q_proj_out = pl.assemble(
-                            q_proj_out,
-                            pl.cast(q_rot_hi, target_type=pl.BF16),
-                            [b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
-                        )
                         q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
                         q_pe_out = pl.assemble(
                             q_pe_out,
                             pl.cast(q_rot_hi, target_type=pl.BF16),
                             [b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
                         )

And refresh the comment block around 264-265 accordingly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
q_proj_out = pl.assemble(
q_proj_out,
pl.cast(q_rot_lo, target_type=pl.BF16),
[b, q_col + QK_NOPE_HEAD_DIM_CFG],
)
q_proj_out = pl.assemble(
q_proj_out,
pl.cast(q_rot_hi, target_type=pl.BF16),
[b, q_col + QK_NOPE_HEAD_DIM_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/models/deepseek_v3_2/deepseek_v3_2_decode_front_scope1.py` around
lines 307 - 316, The writes that assemble q_rot_lo/q_rot_hi back into the local
q_proj_out buffer (via q_proj_out = pl.assemble(...)) are dead because
q_proj_out is never read later; only q_pe_out is returned and used — remove the
two in-place RoPE assemble calls that write to q_proj_out (the ones using
q_rot_lo and q_rot_hi with QK_NOPE_HEAD_DIM_CFG and QK_ROPE_HEAD_DIM_CFG
offsets) and clean up the surrounding comment near the earlier "keeping the
internal full Q layout" notes to reflect that we now directly assemble q_pe_out
from q_rot_lo/q_rot_hi instead of maintaining a full Q layout in q_proj_out.
Ensure references to q_rot_lo, q_rot_hi, q_proj_out, q_pe_out,
QK_NOPE_HEAD_DIM_CFG and QK_ROPE_HEAD_DIM_CFG are updated or removed in comments
as appropriate.

q_pe_out = pl.assemble(q_pe_out, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_pe_col])
q_pe_out = pl.assemble(
q_pe_out,
pl.cast(q_rot_hi, target_type=pl.BF16),
[b, q_pe_col + QK_ROPE_HEAD_DIM_CFG // 2],
)

with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for b in pl.parallel(0, BATCH_CFG, 1, chunk=4):
ctx_len = pl.tensor.read(seq_lens, [b])
pos = ctx_len - 1
cache_row = b * MAX_SEQ_CFG + pos

cos_lo = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
cos_hi = pl.slice(
rope_cos, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
)
sin_lo = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, 0])
sin_hi = pl.slice(
rope_sin, [1, QK_ROPE_HEAD_DIM_CFG // 2], [pos, QK_ROPE_HEAD_DIM_CFG // 2]
)
kv_normed_row = pl.slice(kv_normed_out, [1, KV_LORA_RANK_CFG], [b, 0])

pe_lo = pl.cast(
pl.slice(kv_a_out, [1, QK_ROPE_HEAD_DIM_CFG // 2], [b, KV_LORA_RANK_CFG]),
target_type=pl.FP32,
)
pe_hi = pl.cast(
pl.slice(
kv_a_out,
[1, QK_ROPE_HEAD_DIM_CFG // 2],
[b, KV_LORA_RANK_CFG + QK_ROPE_HEAD_DIM_CFG // 2],
),
target_type=pl.FP32,
)

pe_rot_lo = pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo))
pe_rot_hi = pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi))

kv_cache = pl.assemble(kv_cache, kv_normed_row, [cache_row, 0])
pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot_lo, target_type=pl.BF16), [cache_row, 0])
pe_cache = pl.assemble(
pe_cache,
pl.cast(pe_rot_hi, target_type=pl.BF16),
[cache_row, QK_ROPE_HEAD_DIM_CFG // 2],
)

Comment on lines +268 to +363
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The two parallel loops over BATCH_CFG (lines 269 and 325) can be fused into a single loop. This optimization avoids redundant reads of seq_lens and redundant slicing of rope_cos/rope_sin tensors, reducing overhead and improving cache locality.

return q_pe_out

return DeepSeekV32DecodeFrontScope1

Expand All @@ -242,6 +375,12 @@ def golden_decode_front_scope1(tensors):
q_norm_weight = tensors["q_norm_weight"].float()
wq_b = tensors["wq_b"].float()
wkv_a = tensors["wkv_a"].float()
seq_lens = tensors["seq_lens"]
rope_cos = tensors["rope_cos"].float()
rope_sin = tensors["rope_sin"].float()
kv_norm_weight = tensors["kv_norm_weight"].float()
kv_cache = tensors["kv_cache"]
pe_cache = tensors["pe_cache"]

# RMSNorm
sq_sum = torch.sum(hidden_states * hidden_states, dim=1, keepdim=True)
Expand All @@ -262,12 +401,42 @@ def golden_decode_front_scope1(tensors):

# Write into output tensor slots
tensors["qr_out"].copy_(qr)
tensors["q_proj_out"].copy_(q_proj)
tensors["kv_a_out"].copy_(kv_a)

half = QK_ROPE_HEAD_DIM // 2
q_proj_view = q_proj.float().view(q_proj.shape[0], NUM_HEADS, QK_HEAD_DIM)
for b in range(kv_a.shape[0]):
ctx_len = int(seq_lens[b].item())
pos = ctx_len - 1
cache_row = b * MAX_SEQ + pos

kv_row = kv_a[b : b + 1, :KV_LORA_RANK].float()
kv_var = torch.mean(kv_row * kv_row, dim=-1, keepdim=True)
kv_normed = kv_row * torch.rsqrt(kv_var + EPS) * kv_norm_weight
kv_cache[cache_row : cache_row + 1].copy_(kv_normed.to(torch.bfloat16))

pe_lo = kv_a[b : b + 1, KV_LORA_RANK : KV_LORA_RANK + half].float()
pe_hi = kv_a[b : b + 1, KV_LORA_RANK + half : KV_LORA_RANK + 2 * half].float()
cos_lo = rope_cos[pos : pos + 1, :half]
cos_hi = rope_cos[pos : pos + 1, half:]
sin_lo = rope_sin[pos : pos + 1, :half]
sin_hi = rope_sin[pos : pos + 1, half:]

q_pe = q_proj_view[b, :, QK_NOPE_HEAD_DIM:]
q_lo = q_pe[:, :half].clone()
q_hi = q_pe[:, half:].clone()
q_proj_view[b, :, QK_NOPE_HEAD_DIM : QK_NOPE_HEAD_DIM + half] = q_lo * cos_lo - q_hi * sin_lo
q_proj_view[b, :, QK_NOPE_HEAD_DIM + half :] = q_hi * cos_hi + q_lo * sin_hi

pe_cache[cache_row : cache_row + 1, :half].copy_((pe_lo * cos_lo - pe_hi * sin_lo).to(torch.bfloat16))
pe_cache[cache_row : cache_row + 1, half:].copy_((pe_hi * cos_hi + pe_lo * sin_hi).to(torch.bfloat16))
tensors["q_nope_out"].copy_(q_proj_view[:, :, :QK_NOPE_HEAD_DIM].reshape(q_proj.shape[0], -1).to(torch.bfloat16))
tensors["q_pe_out"].copy_(q_proj_view[:, :, QK_NOPE_HEAD_DIM:].reshape(q_proj.shape[0], -1).to(torch.bfloat16))


def build_tensor_specs(
batch: int = BATCH,
max_seq_len: int = MAX_SEQ,
hidden_size: int = HIDDEN,
num_heads: int = NUM_HEADS,
q_lora_rank: int = Q_LORA_RANK,
Expand All @@ -281,6 +450,8 @@ def build_tensor_specs(

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
kv_a_out = kv_lora_rank + qk_rope_head_dim
cache_rows = batch * max_seq_len
seq_lens_data = torch.randint(1, max_seq_len + 1, (batch,), dtype=torch.int32)

def init_hidden_states():
return torch.rand(batch, hidden_size) - 0.5
Expand All @@ -300,16 +471,47 @@ def init_wq_b():
def init_wkv_a():
return (torch.rand(hidden_size, kv_a_out) - 0.5) / hidden_size ** 0.5

def init_kv_norm_weight():
return torch.rand(1, kv_lora_rank) - 0.5

def init_rope():
return torch.rand(max_seq_len, qk_rope_head_dim) - 0.5

def init_cache_kv():
return torch.zeros(cache_rows, kv_lora_rank)

def init_cache_pe():
return torch.zeros(cache_rows, qk_rope_head_dim)

return [
TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, init_value=init_hidden_states),
TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, init_value=init_rms_weight),
TensorSpec("wq_a", [hidden_size, q_lora_rank], torch.bfloat16, init_value=init_wq_a),
TensorSpec("q_norm_weight", [1, q_lora_rank], torch.float32, init_value=init_q_norm_weight),
TensorSpec("wq_b", [q_lora_rank, num_heads * qk_head_dim], torch.bfloat16, init_value=init_wq_b),
TensorSpec("wkv_a", [hidden_size, kv_a_out], torch.bfloat16, init_value=init_wkv_a),
TensorSpec("seq_lens", [batch], torch.int32, init_value=seq_lens_data),
TensorSpec("rope_cos", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=init_rope),
TensorSpec("rope_sin", [max_seq_len, qk_rope_head_dim], torch.float32, init_value=init_rope),
TensorSpec("kv_norm_weight", [1, kv_lora_rank], torch.float32, init_value=init_kv_norm_weight),
TensorSpec("qr_out", [batch, q_lora_rank], torch.bfloat16, is_output=True),
TensorSpec("q_proj_out", [batch, num_heads * qk_head_dim], torch.bfloat16, is_output=True),
TensorSpec("q_nope_out", [batch, num_heads * qk_nope_head_dim], torch.bfloat16, is_output=True),
TensorSpec("q_pe_out", [batch, num_heads * qk_rope_head_dim], torch.bfloat16, is_output=True),
TensorSpec("kv_a_out", [batch, kv_a_out], torch.bfloat16, is_output=True),
TensorSpec(
"kv_cache",
[cache_rows, kv_lora_rank],
torch.bfloat16,
init_value=init_cache_kv,
is_output=True,
),
TensorSpec(
"pe_cache",
[cache_rows, qk_rope_head_dim],
torch.bfloat16,
init_value=init_cache_pe,
is_output=True,
),
]


Expand Down
Loading