Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions samples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ add_fusilli_samples(
sdpa/sdpa_fprop_with_mask.cpp
sdpa/sdpa_fprop_dropout.cpp
sdpa/sdpa_fprop_gqa.cpp
sdpa/sdpa_fprop_gqa_independent_kv.cpp
sdpa/sdpa_fprop_cross_attn.cpp
DEPS
libfusilli
Expand Down
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_basic_mha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
// Basic multi-head attention: Q and K share the same head count (8), so no
// grouped-query mapping is involved.
TEST_CASE("SDPA forward: basic MHA f16", "[sdpa][custom_op][graph]") {
  FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
  executeSdpa(handle, DataType::Half,
              /*batch=*/1, /*headsQ=*/8, /*headsK=*/8,
              /*seqQ=*/64, /*seqKV=*/64, /*headDim=*/64);
}
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_causal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// Causal attention: each query position may only attend to keys at the same
// or earlier positions (isCausal=true applies the lower-triangular mask).
TEST_CASE("SDPA forward: causal f16", "[sdpa][custom_op][graph]") {
  FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
  executeSdpa(handle, DataType::Half,
              /*batch=*/1, /*headsQ=*/8, /*headsK=*/8,
              /*seqQ=*/64, /*seqKV=*/64, /*headDim=*/64,
              /*isCausal=*/true);
}
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_cross_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
// Cross attention: query sequence length (32) differs from the key/value
// sequence length (128); head counts still match (standard MHA).
TEST_CASE("SDPA forward: cross attention f16", "[sdpa][custom_op][graph]") {
  FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
  executeSdpa(handle, DataType::Half,
              /*batch=*/1, /*headsQ=*/8, /*headsK=*/8,
              /*seqQ=*/32, /*seqKV=*/128, /*headDim=*/64);
}
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_custom_scale.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// Custom softmax scale: overrides the default 1/sqrt(headDim) with an
// explicit 0.125f scale factor.
TEST_CASE("SDPA forward: custom scale f16", "[sdpa][custom_op][graph]") {
  FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
  executeSdpa(handle, DataType::Half,
              /*batch=*/1, /*headsQ=*/8, /*headsK=*/8,
              /*seqQ=*/64, /*seqKV=*/64, /*headDim=*/64,
              /*isCausal=*/false, /*scale=*/0.125f);
}
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_dropout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ TEST_CASE("SDPA forward: dropout f16",
"[sdpa][custom_op][graph][!shouldfail]") {
FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
executeSdpa(handle, DataType::Half,
/*batch=*/1, /*headsQ=*/8, /*headsKV=*/8,
/*batch=*/1, /*headsQ=*/8, /*headsK=*/8,
/*seqQ=*/64, /*seqKV=*/64, /*headDim=*/64,
/*isCausal=*/false, /*scale=*/std::nullopt,
/*enableGqa=*/false, /*hasAttnMask=*/false,
Expand Down
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_gqa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TEST_CASE("SDPA forward: GQA f16", "[sdpa][custom_op][graph]") {
FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
executeSdpa(handle, DataType::Half,
/*batch=*/1, /*headsQ=*/8, /*headsKV=*/2,
/*batch=*/1, /*headsQ=*/8, /*headsK=*/2,
/*seqQ=*/64, /*seqKV=*/64, /*headDim=*/64,
/*isCausal=*/false, /*scale=*/std::nullopt,
/*enableGqa=*/true);
Expand Down
25 changes: 25 additions & 0 deletions samples/sdpa/sdpa_fprop_gqa_independent_kv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2026 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <fusilli.h>

#include "sdpa_utils.h"
#include "utils.h"

#include <catch2/catch_test_macros.hpp>

#include <optional>

// GQA with independent K and V head counts: 8 query heads group onto 4 K
// heads and 2 V heads, exercising the headsV parameter of executeSdpa.
TEST_CASE("SDPA forward: GQA independent K/V heads f16",
          "[sdpa][custom_op][graph]") {
  FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
  // Name the dimensions up front so the grouping ratios (8/4 and 8/2) read
  // directly off the constants.
  constexpr int64_t kBatch = 1, kHeadsQ = 8, kHeadsK = 4, kHeadsV = 2;
  constexpr int64_t kSeqQ = 64, kSeqKV = 64, kHeadDim = 64;
  executeSdpa(handle, DataType::Half, kBatch, kHeadsQ, kHeadsK, kSeqQ, kSeqKV,
              kHeadDim,
              /*isCausal=*/false, /*scale=*/std::nullopt,
              /*enableGqa=*/true, /*hasAttnMask=*/false,
              /*dropoutP=*/0.0f, kHeadsV);
}
2 changes: 1 addition & 1 deletion samples/sdpa/sdpa_fprop_with_mask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TEST_CASE("SDPA forward: with attention mask f16", "[sdpa][custom_op][graph]") {
FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend));
executeSdpa(handle, DataType::Half,
/*batch=*/1, /*headsQ=*/8, /*headsKV=*/8,
/*batch=*/1, /*headsQ=*/8, /*headsK=*/8,
/*seqQ=*/64, /*seqKV=*/64, /*headDim=*/64,
/*isCausal=*/false, /*scale=*/std::nullopt,
/*enableGqa=*/false, /*hasAttnMask=*/true);
Expand Down
55 changes: 34 additions & 21 deletions samples/sdpa/sdpa_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,27 @@ static std::string buildSdpaMlir(bool hasAttnMask = false,
// CPU reference implementation of scaled dot-product attention.
// Computes SDPA in float precision for numerical verification against the GPU.
// Layout: [batch, heads, seq_len, head_dim] contiguous.
// K and V may have independent head counts (headsK, headsV); when headsV is
// not specified it defaults to headsK (standard GQA).
static std::vector<float>
referenceSdpa(float qVal, float kVal, float vVal, float maskVal, int64_t batch,
int64_t headsQ, int64_t headsKV, int64_t seqQ, int64_t seqKV,
int64_t headsQ, int64_t headsK, int64_t seqQ, int64_t seqKV,
int64_t headDim, bool isCausal, std::optional<float> scale,
bool enableGqa, bool hasAttnMask) {
bool enableGqa, bool hasAttnMask, int64_t headsV = 0) {
if (headsV == 0)
headsV = headsK;
float s = scale.value_or(1.0f / std::sqrt(static_cast<float>(headDim)));
int64_t outSize = batch * headsQ * seqQ * headDim;
std::vector<float> out(outSize);

for (int64_t b = 0; b < batch; ++b) {
for (int64_t hq = 0; hq < headsQ; ++hq) {
// Map query head to KV head (identity for MHA, grouped for GQA).
int64_t hkv = enableGqa ? hq / (headsQ / headsKV) : hq;
// Map query head to K and V heads independently.
int64_t hk = enableGqa ? hq / (headsQ / headsK) : hq;
int64_t hv = enableGqa ? hq / (headsQ / headsV) : hq;

for (int64_t sq = 0; sq < seqQ; ++sq) {
// Compute attention scores: dot(Q[b,hq,sq,:], K[b,hkv,sk,:]) * scale.
// Compute attention scores: dot(Q[b,hq,sq,:], K[b,hk,sk,:]) * scale.
// Since Q and K are constant-filled, dot product = qVal * kVal *
// headDim for every (sq, sk) pair. We still compute per-element to
// handle causal/mask variations correctly.
Expand Down Expand Up @@ -166,7 +171,7 @@ referenceSdpa(float qVal, float kVal, float vVal, float maskVal, int64_t batch,
scores[sk] /= sumExp;

// Output: weighted sum of V rows.
// V[b, hkv, sk, d] = vVal for all elements.
// V[b, hv, sk, d] = vVal for all elements.
for (int64_t d = 0; d < headDim; ++d) {
float val = 0.0f;
for (int64_t sk = 0; sk < seqKV; ++sk)
Expand All @@ -181,22 +186,29 @@ referenceSdpa(float qVal, float kVal, float vVal, float maskVal, int64_t batch,

// Build a graph that runs scaled dot-product attention on Q, K, V tensors.
// Shape convention: [batch, heads, seq_len, head_dim].
// K and V may have independent head counts; when headsV is 0 (default) it
// uses headsK (standard GQA where H_k == H_v).
static void executeSdpa(Handle &handle, DataType dt, int64_t batch,
int64_t headsQ, int64_t headsKV, int64_t seqQ,
int64_t headsQ, int64_t headsK, int64_t seqQ,
int64_t seqKV, int64_t headDim, bool isCausal = false,
std::optional<float> scale = std::nullopt,
bool enableGqa = false, bool hasAttnMask = false,
float dropoutP = 0.0f) {
float dropoutP = 0.0f, int64_t headsV = 0) {
if (headsV == 0)
headsV = headsK;

// attn_mask and is_causal are mutually exclusive: is_causal internally
// applies a causal mask, making an explicit mask contradictory.
REQUIRE(!(hasAttnMask && isCausal));

if (enableGqa) {
// GQA constraint: query heads must be a multiple of KV heads.
REQUIRE(headsQ % headsKV == 0);
// GQA constraint: query heads must be a multiple of both K and V heads.
REQUIRE(headsQ % headsK == 0);
REQUIRE(headsQ % headsV == 0);
} else {
// Standard MHA: query and KV head counts must match.
REQUIRE(headsQ == headsKV);
// Standard MHA: query, K, and V head counts must all match.
REQUIRE(headsQ == headsK);
REQUIRE(headsQ == headsV);
}

std::string causalSuffix = isCausal ? "_causal" : "";
Expand All @@ -209,9 +221,10 @@ static void executeSdpa(Handle &handle, DataType dt, int64_t batch,

auto graph = std::make_shared<Graph>();
graph
->setName(std::format("sdpa_b{}hq{}hkv{}sq{}skv{}d{}{}{}{}{}{}", batch,
headsQ, headsKV, seqQ, seqKV, headDim, causalSuffix,
maskSuffix, gqaSuffix, scaleSuffix, dropoutSuffix))
->setName(std::format("sdpa_b{}hq{}hk{}hv{}sq{}skv{}d{}{}{}{}{}{}", batch,
headsQ, headsK, headsV, seqQ, seqKV, headDim,
causalSuffix, maskSuffix, gqaSuffix, scaleSuffix,
dropoutSuffix))
.setIODataType(dt)
.setIntermediateDataType(dt);

Expand All @@ -222,15 +235,15 @@ static void executeSdpa(Handle &handle, DataType dt, int64_t batch,
auto qT =
graph->tensor(TensorAttr().setName("q").setDim(qDim).setStride(qStride));

// K: [batch, headsKV, seqKV, headDim]
std::vector<int64_t> kDim = {batch, headsKV, seqKV, headDim};
// K: [batch, headsK, seqKV, headDim]
std::vector<int64_t> kDim = {batch, headsK, seqKV, headDim};
auto kStride =
generateStrideFromDim(kDim, getContiguousStrideOrder(kDim.size()));
auto kT =
graph->tensor(TensorAttr().setName("k").setDim(kDim).setStride(kStride));

// V: [batch, headsKV, seqKV, headDim]
std::vector<int64_t> vDim = {batch, headsKV, seqKV, headDim};
// V: [batch, headsV, seqKV, headDim]
std::vector<int64_t> vDim = {batch, headsV, seqKV, headDim};
auto vStride =
generateStrideFromDim(vDim, getContiguousStrideOrder(vDim.size()));
auto vT =
Expand Down Expand Up @@ -307,8 +320,8 @@ static void executeSdpa(Handle &handle, DataType dt, int64_t batch,
constexpr float kInitMask = -1.0f;

auto expected = referenceSdpa(kInitQ, kInitK, kInitV, kInitMask, batch,
headsQ, headsKV, seqQ, seqKV, headDim, isCausal,
scale, enableGqa, hasAttnMask);
headsQ, headsK, seqQ, seqKV, headDim, isCausal,
scale, enableGqa, hasAttnMask, headsV);

// f16 has ~3 decimal digits of precision; use a tolerance that accounts
// for accumulation error across the softmax and weighted-sum steps.
Expand Down
Loading