diff --git a/examples/models/qwen3/qwen3_32b_decode.py b/examples/models/qwen3/qwen3_32b_decode.py index 343b155..712a5b3 100644 --- a/examples/models/qwen3/qwen3_32b_decode.py +++ b/examples/models/qwen3/qwen3_32b_decode.py @@ -252,10 +252,9 @@ def qwen3_decode( # Stage 2: QK matmul for all active sb blocks. all_raw_scores = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.FP32) - all_exp_padded = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, SEQ_TILE], dtype=pl.BF16) all_oi_tmp = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) - all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) - all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_BATCH, 1], dtype=pl.FP32) + all_cur_mi = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): s0 = sb * SEQ_TILE @@ -268,16 +267,17 @@ def qwen3_decode( raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]) - # Stage 3: softmax for all active sb blocks. + # Stage 3: softmax + SV matmul for all active sb blocks. with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): s0 = sb * SEQ_TILE + cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 valid_len = pl.min(SEQ_TILE, ctx_len - s0) scores_valid = pl.slice( all_raw_scores, - [Q_HEAD_BATCH, SEQ_TILE], + [Q_HEAD_PAD, SEQ_TILE], [sb * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_BATCH, valid_len], + valid_shape=[Q_HEAD_PAD, valid_len], ) scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) scores = pl.mul(scores_padded, attn_scale) @@ -286,41 +286,29 @@ def qwen3_decode( exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32) cur_li = pl.row_sum(exp_scores_fp32) - all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0]) - all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH, 0]) - all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_BATCH, 0]) - - # Stage 4: SV matmul for all active sb blocks. - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - s0 = sb * SEQ_TILE - cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 - exp_tile = pl.slice( - all_exp_padded, - [Q_HEAD_PAD, SEQ_TILE], - [sb * Q_HEAD_PAD, 0], - ) v_tile = pl.slice( v_cache, [SEQ_TILE, head_dim], [cache_row0, 0], ) - oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) + oi_tmp = pl.matmul(exp_scores_bf16, v_tile, out_dtype=pl.FP32) all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) + all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_PAD, 0]) + all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_PAD, 0]) - # Stage 5: online softmax accumulation and normalisation. + # Stage 4: online softmax accumulation and normalisation. with pl.at(level=pl.Level.CORE_GROUP): oi = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [0, 0]) mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [0, 0]) li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [0, 0]) for sb in pl.range(1, ctx_blocks): oi_tmp_valid = pl.slice(all_oi_tmp, [Q_HEAD_BATCH, head_dim], [sb * Q_HEAD_PAD, 0]) - cur_mi = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) - cur_li = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_BATCH, 0]) - mi_new = pl.maximum(mi, cur_mi) + cur_mi_valid = pl.slice(all_cur_mi, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_PAD, 0]) + cur_li_valid = pl.slice(all_cur_li, [Q_HEAD_BATCH, 1], [sb * Q_HEAD_PAD, 0]) + mi_new = pl.maximum(mi, cur_mi_valid) alpha = pl.exp(pl.sub(mi, mi_new)) - beta = pl.exp(pl.sub(cur_mi, mi_new)) - li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) + beta = pl.exp(pl.sub(cur_mi_valid, mi_new)) + li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li_valid)) oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp_valid, beta)) mi = mi_new