Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 15 additions & 27 deletions examples/models/qwen3/qwen3_32b_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,9 @@ def qwen3_decode(

# Stage 2: QK matmul for all active sb blocks.
all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32)
all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16)
all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32)
all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32)
all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32)
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
Expand All @@ -268,16 +267,17 @@ def qwen3_decode(
raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])

# Stage 3: softmax for all active sb blocks.
# Stage 3: softmax + SV matmul for all active sb blocks.
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
valid_len = pl.min(SEQ_TILE, ctx_len - s0)
scores_valid = pl.slice(
all_raw_scores,
[Q_HEAD_BATCH, SEQ_TILE],
[Q_HEAD_PAD, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_BATCH, valid_len],
valid_shape=[Q_HEAD_PAD, valid_len],
)
scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min)
scores = pl.mul(scores_padded, attn_scale)
Expand All @@ -286,41 +286,29 @@ def qwen3_decode(
exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16)
exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32)
cur_li = pl.row_sum(exp_scores_fp32)
all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0])
all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0])
all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0])

# Stage 4: SV matmul for all active sb blocks.
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH):
s0 = sb * SEQ_TILE
cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
exp_tile = pl.slice(
all_exp_padded,
[Q_HEAD_PAD, SEQ_TILE],
[sb * Q_HEAD_PAD, 0],
)
v_tile = pl.slice(
v_cache,
[SEQ_TILE, head_dim],
[cache_row0, 0],
)
oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32)
oi_tmp = pl.matmul(exp_scores_bf16, v_tile, out_dtype=pl.FP32)
all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0])
all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_PAD, 0])
all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_PAD, 0])

# Stage 5: online softmax accumulation and normalisation.
# Stage 4: online softmax accumulation and normalisation.
with pl.at(level=pl.Level.CORE_GROUP):
oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0])
mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0])
li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0])
for sb in pl.range(1, ctx_blocks):
oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0])
cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0])
mi_new = pl.maximum(mi, cur_mi)
cur_mi_valid = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_PAD, 0])
cur_li_valid = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_PAD, 0])
mi_new = pl.maximum(mi, cur_mi_valid)
alpha = pl.exp(pl.sub(mi, mi_new))
beta = pl.exp(pl.sub(cur_mi, mi_new))
li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li))
beta = pl.exp(pl.sub(cur_mi_valid, mi_new))
li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li_valid))
oi = pl.add(pl.row_expand_mul(oi, alpha),
pl.row_expand_mul(oi_tmp_valid, beta))
mi = mi_new
Expand Down
Loading