diff --git a/include/fusilli.h b/include/fusilli.h index 6dca114d..2f7e3023 100644 --- a/include/fusilli.h +++ b/include/fusilli.h @@ -30,27 +30,29 @@ #include "fusilli/support/target_platform.h" // IWYU pragma: export // Attributes / Types: -#include "fusilli/attributes/attributes.h" // IWYU pragma: export -#include "fusilli/attributes/common.h" // IWYU pragma: export -#include "fusilli/attributes/conv_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/custom_op_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/layernorm_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/matmul_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/pointwise_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/reduction_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/rmsnorm_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/tensor_attributes.h" // IWYU pragma: export -#include "fusilli/attributes/types.h" // IWYU pragma: export +#include "fusilli/attributes/attributes.h" // IWYU pragma: export +#include "fusilli/attributes/blocked_matmul_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/common.h" // IWYU pragma: export +#include "fusilli/attributes/conv_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/custom_op_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/layernorm_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/matmul_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/pointwise_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/reduction_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/rmsnorm_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/tensor_attributes.h" // IWYU pragma: export +#include "fusilli/attributes/types.h" // IWYU pragma: export // Nodes: -#include "fusilli/node/conv_node.h" // IWYU pragma: export -#include 
"fusilli/node/custom_op_node.h" // IWYU pragma: export -#include "fusilli/node/layernorm_node.h" // IWYU pragma: export -#include "fusilli/node/matmul_node.h" // IWYU pragma: export -#include "fusilli/node/node.h" // IWYU pragma: export -#include "fusilli/node/pointwise_node.h" // IWYU pragma: export -#include "fusilli/node/reduction_node.h" // IWYU pragma: export -#include "fusilli/node/rmsnorm_node.h" // IWYU pragma: export +#include "fusilli/node/blocked_matmul_node.h" // IWYU pragma: export +#include "fusilli/node/conv_node.h" // IWYU pragma: export +#include "fusilli/node/custom_op_node.h" // IWYU pragma: export +#include "fusilli/node/layernorm_node.h" // IWYU pragma: export +#include "fusilli/node/matmul_node.h" // IWYU pragma: export +#include "fusilli/node/node.h" // IWYU pragma: export +#include "fusilli/node/pointwise_node.h" // IWYU pragma: export +#include "fusilli/node/reduction_node.h" // IWYU pragma: export +#include "fusilli/node/rmsnorm_node.h" // IWYU pragma: export // Backend: #include "fusilli/backend/backend.h" // IWYU pragma: export diff --git a/include/fusilli/attributes/blocked_matmul_attributes.h b/include/fusilli/attributes/blocked_matmul_attributes.h new file mode 100644 index 00000000..b24841e9 --- /dev/null +++ b/include/fusilli/attributes/blocked_matmul_attributes.h @@ -0,0 +1,48 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains attributes (compile-time constant metadata) for +// blocked matmul nodes. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FUSILLI_ATTRIBUTES_BLOCKED_MATMUL_ATTRIBUTES_H
+#define FUSILLI_ATTRIBUTES_BLOCKED_MATMUL_ATTRIBUTES_H
+
+#include "fusilli/attributes/attributes.h"
+#include "fusilli/attributes/tensor_attributes.h"
+
+#include <cstdint>
+#include <memory>
+#include <unordered_map>
+
+namespace fusilli {
+
+class BlockedMatmulAttr : public AttributesCRTP<BlockedMatmulAttr> {
+public:
+  // Names for Tensor Inputs and Outputs.
+  enum class InputNames : uint8_t { LHS, RHS };
+  enum class OutputNames : uint8_t { RESULT };
+
+  std::unordered_map<InputNames, std::shared_ptr<TensorAttr>> inputs;
+  std::unordered_map<OutputNames, std::shared_ptr<TensorAttr>> outputs;
+
+  // Setters:
+  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BlockedMatmulAttr, InputNames, LHS)
+  FUSILLI_GENERIC_INPUT_TENSOR_SETTER(BlockedMatmulAttr, InputNames, RHS)
+  FUSILLI_GENERIC_OUTPUT_TENSOR_SETTER(BlockedMatmulAttr, OutputNames, RESULT)
+
+  // Getters:
+  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, LHS)
+  FUSILLI_GENERIC_INPUT_TENSOR_GETTER(InputNames, RHS)
+  FUSILLI_GENERIC_OUTPUT_TENSOR_GETTER(OutputNames, RESULT)
+};
+
+} // namespace fusilli
+
+#endif // FUSILLI_ATTRIBUTES_BLOCKED_MATMUL_ATTRIBUTES_H
diff --git a/include/fusilli/graph/graph.h b/include/fusilli/graph/graph.h
index 924db182..dd88b9a1 100644
--- a/include/fusilli/graph/graph.h
+++ b/include/fusilli/graph/graph.h
@@ -14,6 +14,7 @@
 #ifndef FUSILLI_GRAPH_GRAPH_H
 #define FUSILLI_GRAPH_GRAPH_H
+#include "fusilli/attributes/blocked_matmul_attributes.h"
 #include "fusilli/attributes/common.h"
 #include "fusilli/attributes/conv_attributes.h"
 #include "fusilli/attributes/custom_op_attributes.h"
@@ -30,6 +31,7 @@
 #include "fusilli/backend/compile_session.h"
 #include "fusilli/backend/handle.h"
 #include "fusilli/graph/context.h"
+#include "fusilli/node/blocked_matmul_node.h"
 #include "fusilli/node/conv_node.h"
 #include "fusilli/node/custom_op_node.h"
 #include "fusilli/node/layernorm_node.h"
@@ -277,6 +279,10 @@ class Graph : public INode {
   std::shared_ptr<TensorAttr> matmul(const std::shared_ptr<TensorAttr> &a,
                                      const std::shared_ptr<TensorAttr> &b,
                                      MatmulAttr &attributes);
+  std::shared_ptr<TensorAttr>
+  blockedMatmul(const std::shared_ptr<TensorAttr> &lhs,
+                const std::shared_ptr<TensorAttr> &rhs,
+                BlockedMatmulAttr &attributes);
   std::shared_ptr<TensorAttr> pointwise(const std::shared_ptr<TensorAttr> &in,
                                         PointwiseAttr &attributes);
@@ -849,6 +855,33 @@ Graph::matmul(const std::shared_ptr<TensorAttr> &a,
   return c;
 }
+
+// Create a BlockedMatmulNode, populate it with the specified attributes, create
+// output tensors and add the node to the graph's sub nodes.
+inline std::shared_ptr<TensorAttr>
+Graph::blockedMatmul(const std::shared_ptr<TensorAttr> &lhs,
+                     const std::shared_ptr<TensorAttr> &rhs,
+                     BlockedMatmulAttr &bmAttr) {
+  if (bmAttr.getName().empty())
+    bmAttr.setName("blocked_matmul_" + std::to_string(subNodes_.size()));
+  if (lhs && lhs->getName().empty())
+    lhs->setName(bmAttr.getName() + "_LHS");
+  if (rhs && rhs->getName().empty())
+    rhs->setName(bmAttr.getName() + "_RHS");
+
+  FUSILLI_LOG_LABEL_ENDL("INFO: Adding BlockedMatmulNode '" << bmAttr.getName()
+                                                            << "' to Graph");
+
+  bmAttr.setLHS(lhs).setRHS(rhs);
+
+  auto out = outputTensor(bmAttr.getName() + "_RESULT");
+  bmAttr.setRESULT(out);
+
+  subNodes_.emplace_back(
+      std::make_unique<BlockedMatmulNode>(std::move(bmAttr), context));
+
+  return out;
+}
+
 // Create a PointwiseNode for single operand cases (e.g. RELU), populate it with
 // the specified attributes, create output tensors and add the node to the
 // graph's sub nodes.
diff --git a/include/fusilli/node/blocked_matmul_node.h b/include/fusilli/node/blocked_matmul_node.h
new file mode 100644
index 00000000..68bc0b2c
--- /dev/null
+++ b/include/fusilli/node/blocked_matmul_node.h
@@ -0,0 +1,170 @@
+// Copyright 2025 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for the blocked matmul node
+// `BlockedMatmulNode`.
+//
+// Blocked matmul operates on 4D tiled tensors:
+//   LHS logical: [M0, K0, M1, K1]
+//   RHS logical: [K0, N0, K1, N1]
+//   OUT:         [M0, N0, M1, N1]
+//
+// When RHS is specified with transposed strides (physical layout
+// [N0, K0, N1, K1]), this lowers to `linalg.mmt4d`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FUSILLI_NODE_BLOCKED_MATMUL_NODE_H
+#define FUSILLI_NODE_BLOCKED_MATMUL_NODE_H
+
+#include "fusilli/attributes/blocked_matmul_attributes.h"
+#include "fusilli/attributes/tensor_attributes.h"
+#include "fusilli/graph/context.h"
+#include "fusilli/node/node.h"
+#include "fusilli/support/logging.h"
+
+#include <cassert>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace fusilli {
+
+//===----------------------------------------------------------------------===//
+// Helper functions for blocked matmul nodes.
+//===----------------------------------------------------------------------===//
+
+// Infer the output shape of a blocked matmul operation.
+// LHS [M0, K0, M1, K1] x RHS [K0, N0, K1, N1] -> OUT [M0, N0, M1, N1]
+inline std::vector<int64_t>
+getBlockedMatmulInferredOutputShape(const std::vector<int64_t> &lhsDim,
+                                    const std::vector<int64_t> &rhsDim) {
+  assert(lhsDim.size() == 4 && "LHS must be rank 4");
+  assert(rhsDim.size() == 4 && "RHS must be rank 4");
+  return {lhsDim[0], rhsDim[1], lhsDim[2], rhsDim[3]};
+}
+
+//===----------------------------------------------------------------------===//
+// Blocked matmul node.
+//===----------------------------------------------------------------------===//
+
+class BlockedMatmulNode : public NodeCRTP<BlockedMatmulNode> {
+public:
+  BlockedMatmulAttr blockedMatmulAttr;
+
+  BlockedMatmulNode(BlockedMatmulAttr &&attr, const Context &ctx)
+      : NodeCRTP(ctx), blockedMatmulAttr(std::move(attr)) {}
+
+  // ASM emitter methods.
+  std::string emitNodePreAsm() const override final;
+
+  const std::string &getName() const override final {
+    return blockedMatmulAttr.getName();
+  }
+  Type getType() const override final { return Type::BlockedMatmul; }
+
+  ErrorObject preValidateNode() const override final {
+    FUSILLI_LOG_LABEL_ENDL("INFO: Pre-Validating BlockedMatmulNode '"
+                           << blockedMatmulAttr.getName() << "'");
+
+    auto lhsT = blockedMatmulAttr.getLHS();
+    auto rhsT = blockedMatmulAttr.getRHS();
+    auto outT = blockedMatmulAttr.getRESULT();
+
+    FUSILLI_RETURN_ERROR_IF(!lhsT, ErrorCode::AttributeNotSet,
+                            "BlockedMatmul input tensor LHS not set");
+    FUSILLI_RETURN_ERROR_IF(!rhsT, ErrorCode::AttributeNotSet,
+                            "BlockedMatmul input tensor RHS not set");
+    FUSILLI_RETURN_ERROR_IF(!outT, ErrorCode::AttributeNotSet,
+                            "BlockedMatmul output tensor OUT not set");
+
+    size_t lhsRank = lhsT->getDim().size();
+    size_t rhsRank = rhsT->getDim().size();
+    FUSILLI_RETURN_ERROR_IF(lhsRank != 4, ErrorCode::InvalidAttribute,
+                            "BlockedMatmul LHS must have rank 4, got " +
+                                std::to_string(lhsRank));
+    FUSILLI_RETURN_ERROR_IF(rhsRank != 4, ErrorCode::InvalidAttribute,
+                            "BlockedMatmul RHS must have rank 4, got " +
+                                std::to_string(rhsRank));
+
+    // K dimensions must match:
+    //   LHS logical [M0, K0, M1, K1], RHS logical [K0, N0, K1, N1]
+    //   LHS[1] == RHS[0] (K0) and LHS[3] == RHS[2] (K1)
+    const auto &lhsDim = lhsT->getDim();
+    const auto &rhsDim = rhsT->getDim();
+    FUSILLI_RETURN_ERROR_IF(
+        lhsDim[1] != rhsDim[0], ErrorCode::InvalidAttribute,
+        "BlockedMatmul K0 mismatch: LHS[1]=" + std::to_string(lhsDim[1]) +
+            ", RHS[0]=" + std::to_string(rhsDim[0]));
+    FUSILLI_RETURN_ERROR_IF(
        lhsDim[3] != rhsDim[2], ErrorCode::InvalidAttribute,
+        "BlockedMatmul K1 mismatch: LHS[3]=" + std::to_string(lhsDim[3]) +
+            ", RHS[2]=" + std::to_string(rhsDim[2]));
+
+    // RHS must be transposed: logical [K0, N0, K1, N1] must have physical
+    // layout [N0, K0, N1, K1] for linalg.mmt4d. This corresponds to
+    // logical-to-physical permutation [1, 0, 3, 2].
+    std::vector<size_t> rhsPerm = rhsT->getLogicalToPhysicalPermuteOrder();
+    std::vector<size_t> expectedPerm = {1, 0, 3, 2};
+    FUSILLI_RETURN_ERROR_IF(
+        rhsPerm != expectedPerm, ErrorCode::NotImplemented,
+        "BlockedMatmul only supports RHS with transposed physical layout "
+        "[N0, K0, N1, K1]. Non-transposed RHS is not yet supported");
+
+    return ok();
+  }
+
+  ErrorObject inferPropertiesNode() override final {
+    FUSILLI_LOG_LABEL_ENDL("INFO: Inferring properties for BlockedMatmulNode '"
+                           << blockedMatmulAttr.getName() << "'");
+
+    blockedMatmulAttr.fillFromContext(context);
+
+    auto lhsT = blockedMatmulAttr.getLHS();
+    auto rhsT = blockedMatmulAttr.getRHS();
+    auto outT = blockedMatmulAttr.getRESULT();
+
+    const auto &outDim = outT->getDim();
+    const auto &outStride = outT->getStride();
+
+    if (outDim.empty())
+      outT->setDim(
+          getBlockedMatmulInferredOutputShape(lhsT->getDim(), rhsT->getDim()));
+
+    if (outStride.empty()) {
+      outT->setStride(generateStrideFromDim(
+          outT->getDim(), getContiguousStrideOrder(outT->getDim().size())));
+    }
+
+    return ok();
+  }
+
+  ErrorObject postValidateNode() const override final {
+    FUSILLI_LOG_LABEL_ENDL("INFO: Post-Validating BlockedMatmulNode '"
+                           << blockedMatmulAttr.getName() << "'");
+
+    auto outT = blockedMatmulAttr.getRESULT();
+    FUSILLI_RETURN_ERROR_IF(outT->getDim().size() != 4,
+                            ErrorCode::InvalidAttribute,
+                            "BlockedMatmul OUT must have rank 4");
+
+    FUSILLI_RETURN_ERROR_IF(
+        outT->getDim() != getBlockedMatmulInferredOutputShape(
+                              blockedMatmulAttr.getLHS()->getDim(),
+                              blockedMatmulAttr.getRHS()->getDim()),
+        ErrorCode::InvalidAttribute,
+        "BlockedMatmul OUT dimensions do not match expected shape");
+
+    return ok();
+  }
+};
+
+} // namespace fusilli
+
+#endif // FUSILLI_NODE_BLOCKED_MATMUL_NODE_H
diff --git a/include/fusilli/node/node.h b/include/fusilli/node/node.h
index 2f03ee4b..ec7011bb 100644
--- a/include/fusilli/node/node.h
+++ b/include/fusilli/node/node.h
@@ -37,6 +37,7 @@ class INode {
     LayerNorm,
     RmsNorm,
     Matmul,
+    BlockedMatmul,
     Reduction,
     Custom,
   };
diff --git a/include/fusilli/support/asm_emitter.h b/include/fusilli/support/asm_emitter.h
index 78f50dc5..7b9093a6 100644
--- a/include/fusilli/support/asm_emitter.h
+++ b/include/fusilli/support/asm_emitter.h
@@ -296,6 +296,27 @@ inline std::string TensorAttr::getValueNameAsm(bool isOutputAliased) const {
   return "%" + filtered + (isOutputAliased ? "_" : "");
 }
+
+// Emits a builtin ranked tensor type in MLIR assembly format.
+// Uses physical dims (the actual memory layout) of the tensor.
+//
+// Example:
+//   tensor with physical dims [16, 16, 8, 4] and DataType::Float
+//   --> "tensor<16x16x8x4xf32>"
+inline std::string
+getBuiltinTensorTypeAsm(const std::shared_ptr<TensorAttr> &tensor) {
+  assert(!tensor->getDim().empty() &&
+         "getBuiltinTensorTypeAsm expects non-empty dims");
+  assert(tensor->getDataType() != DataType::NotSet &&
+         "getBuiltinTensorTypeAsm expects a valid data type");
+
+  std::ostringstream oss;
+  oss << "tensor<";
+  for (auto dim : tensor->getPhysicalDim())
+    oss << dim << "x";
+  oss << kDataTypeToMlirTypeAsm.at(tensor->getDataType()) << ">";
+  return oss.str();
+}
+
 //===----------------------------------------------------------------------===//
 //
 // Graph ASM Emitter Methods
@@ -1350,6 +1371,75 @@ inline std::string MatmulNode::emitNodePreAsm() const {
   return output;
 }
+
+//===----------------------------------------------------------------------===//
+//
+// BlockedMatmulNode ASM Emitter Methods
+//
+//===----------------------------------------------------------------------===//
+
+// Emits blocked matmul as linalg.mmt4d with torch_c casts.
+// +// The emitter operates on physical tensor layouts directly (no permute ops). +// Function arguments are in physical layout; torch_c casts bridge to builtin +// tensors for linalg.mmt4d, then cast the result back to torch. +// +// Generated MLIR pattern: +// %lhs = torch_c.to_builtin_tensor %arg_lhs : !torch.vtensor<[phys],dt> -> +// tensor %rhs = torch_c.to_builtin_tensor %arg_rhs : +// !torch.vtensor<[phys],dt> -> tensor %cst = arith.constant 0.0 : dt +// %empty = tensor.empty() : tensor +// %fill = linalg.fill ins(%cst : dt) outs(%empty : ...) -> ... +// %mmt4d = linalg.mmt4d ins(%lhs, %rhs : ...) outs(%fill : ...) -> ... +// %result = torch_c.from_builtin_tensor %mmt4d : ... -> +// !torch.vtensor<[phys],dt> +inline std::string BlockedMatmulNode::emitNodePreAsm() const { + auto lhsT = blockedMatmulAttr.getLHS(); + auto rhsT = blockedMatmulAttr.getRHS(); + auto outT = blockedMatmulAttr.getRESULT(); + std::string suffix = blockedMatmulAttr.getName(); + + std::string lhsTorchType = lhsT->getTensorTypeAsm(/*isValueTensor=*/true, + /*useLogicalDims=*/false); + std::string rhsTorchType = rhsT->getTensorTypeAsm(/*isValueTensor=*/true, + /*useLogicalDims=*/false); + std::string outTorchType = outT->getTensorTypeAsm(/*isValueTensor=*/true, + /*useLogicalDims=*/false); + + std::string lhsBuiltinType = getBuiltinTensorTypeAsm(lhsT); + std::string rhsBuiltinType = getBuiltinTensorTypeAsm(rhsT); + std::string outBuiltinType = getBuiltinTensorTypeAsm(outT); + + std::string mlirType = kDataTypeToMlirTypeAsm.at(outT->getDataType()); + + std::string lhsName = lhsT->getValueNameAsm(); + std::string rhsName = rhsT->getValueNameAsm(); + std::string outName = outT->getValueNameAsm(); + + constexpr std::string_view schema = R"( + %{0}_lhs_builtin = torch_c.to_builtin_tensor {1} : {2} -> {3} + %{0}_rhs_builtin = torch_c.to_builtin_tensor {4} : {5} -> {6} + %{0}_cst = arith.constant 0.000000e+00 : {7} + %{0}_empty = tensor.empty() : {8} + %{0}_fill = linalg.fill ins(%{0}_cst : 
{7}) outs(%{0}_empty : {8}) -> {8} + %{0}_mmt4d = linalg.mmt4d ins(%{0}_lhs_builtin, %{0}_rhs_builtin : {3}, {6}) outs(%{0}_fill : {8}) -> {8} + {9} = torch_c.from_builtin_tensor %{0}_mmt4d : {8} -> {10} + )"; + + return std::format(schema, + suffix, // {0} unique prefix + lhsName, // {1} LHS SSA name + lhsTorchType, // {2} LHS torch type + lhsBuiltinType, // {3} LHS builtin type + rhsName, // {4} RHS SSA name + rhsTorchType, // {5} RHS torch type + rhsBuiltinType, // {6} RHS builtin type + mlirType, // {7} scalar element type + outBuiltinType, // {8} OUT builtin type + outName, // {9} OUT SSA name + outTorchType // {10} OUT torch type + ); +} + //===----------------------------------------------------------------------===// // // PointwiseNode ASM Emitter Methods diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 0a95ea85..079f14d3 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -56,6 +56,16 @@ add_fusilli_samples( Catch2::Catch2WithMain ) +add_fusilli_samples( + PREFIX fusilli_blocked_matmul_samples + SRCS + blocked_matmul/blocked_matmul_basic.cpp + DEPS + libfusilli + libutils + Catch2::Catch2WithMain +) + add_fusilli_samples( PREFIX fusilli_layernorm_samples SRCS diff --git a/samples/blocked_matmul/blocked_matmul_basic.cpp b/samples/blocked_matmul/blocked_matmul_basic.cpp new file mode 100644 index 00000000..485c293c --- /dev/null +++ b/samples/blocked_matmul/blocked_matmul_basic.cpp @@ -0,0 +1,94 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "utils.h" + +#include + +#include +#include +#include +#include +#include +#include + +using namespace fusilli; + +// Blocked matmul (mmt4d): +// LHS [M0, K0, M1, K1] x RHS [K0, N0, K1, N1] -> OUT [M0, N0, M1, N1] +// RHS is transposed: physical layout [N0, K0, N1, K1] +// +// When all inputs are ones, each output element equals K0 * K1 (the +// contraction dimension product). +TEST_CASE("Blocked matmul; LHS [M0,K0,M1,K1], RHS transposed; basic mmt4d", + "[blocked_matmul][graph]") { + int64_t m0 = 4, k0 = 8, m1 = 4, k1 = 2; + int64_t n0 = 6, n1 = 4; + + auto buildNewGraph = [=](const Handle &handle) { + auto graph = std::make_shared(); + graph->setName("blocked_matmul_basic_sample"); + graph->setIODataType(DataType::Float).setComputeDataType(DataType::Float); + + // LHS: logical [m0, k0, m1, k1], contiguous + auto lhsT = graph->tensor(TensorAttr() + .setName("lhs") + .setDim({m0, k0, m1, k1}) + .setStride({k0 * m1 * k1, m1 * k1, k1, 1})); + + // RHS: logical [k0, n0, k1, n1], physical [n0, k0, n1, k1] (transposed) + auto rhsT = graph->tensor(TensorAttr() + .setName("rhs") + .setDim({k0, n0, k1, n1}) + .setStride({n1 * k1, k0 * n1 * k1, 1, k1})); + + auto bmAttr = BlockedMatmulAttr().setName("blocked_matmul"); + auto outT = graph->blockedMatmul(lhsT, rhsT, bmAttr); + outT->setOutput(true); + + FUSILLI_REQUIRE_OK(graph->validate()); + FUSILLI_REQUIRE_OK(graph->compile(handle, /*remove=*/true)); + + return std::make_tuple(graph, lhsT, rhsT, outT); + }; + + FUSILLI_REQUIRE_ASSIGN(Handle handle, Handle::create(kDefaultBackend)); + + auto [graph, lhsT, rhsT, outT] = buildNewGraph(handle); + + // Allocate input buffers (all ones). + FUSILLI_REQUIRE_ASSIGN( + auto lhsBuf, allocateBufferOfType(handle, lhsT, DataType::Float, 1.0f)); + FUSILLI_REQUIRE_ASSIGN( + auto rhsBuf, allocateBufferOfType(handle, rhsT, DataType::Float, 1.0f)); + + // Allocate output buffer (zeros). 
+ FUSILLI_REQUIRE_ASSIGN( + auto outBuf, allocateBufferOfType(handle, outT, DataType::Float, 0.0f)); + + const std::unordered_map, std::shared_ptr> + variantPack = { + {lhsT, lhsBuf}, + {rhsT, rhsBuf}, + {outT, outBuf}, + }; + + FUSILLI_REQUIRE_ASSIGN(auto workspace, + allocateWorkspace(handle, graph->getWorkspaceSize())); + + FUSILLI_REQUIRE_OK(graph->execute(handle, variantPack, workspace)); + + std::vector result; + FUSILLI_REQUIRE_OK(outBuf->read(handle, result)); + + // When LHS and RHS are all ones, each output element = k0 * k1. + float expected = static_cast(k0 * k1); + for (size_t i = 0; i < result.size(); ++i) { + REQUIRE(result[i] == expected); + } +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e0dcc506..9b6207f8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -166,6 +166,7 @@ add_fusilli_lit_tests( lit/test_layernorm_train_asm_emitter_scale_bias_nhwc.cpp lit/test_rmsnorm_infer_asm_emitter_nchw.cpp lit/test_rmsnorm_infer_asm_emitter_scale_nhwc.cpp + lit/test_blocked_matmul_asm_emitter.cpp lit/test_matmul_asm_emitter_basic.cpp lit/test_matmul_asm_emitter_batched.cpp lit/test_matmul_asm_emitter_broadcast_3D.cpp diff --git a/tests/lit/test_blocked_matmul_asm_emitter.cpp b/tests/lit/test_blocked_matmul_asm_emitter.cpp new file mode 100644 index 00000000..f757ff4c --- /dev/null +++ b/tests/lit/test_blocked_matmul_asm_emitter.cpp @@ -0,0 +1,111 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// RUN: %{TEST_EXE} | FileCheck %s --check-prefix=TORCH-CHECK +// RUN: %{TEST_EXE} | iree-compile - \ +// RUN: --iree-hal-target-backends=llvm-cpu \ +// RUN: --iree-llvmcpu-target-cpu=host \ +// RUN: --iree-torch-externalize-transients \ +// RUN: --iree-torch-enable-shape-refinement \ +// RUN: --compile-to=flow | \ +// RUN: FileCheck %s --check-prefix=FLOW-CHECK +// RUN: %{TEST_EXE} stats | FileCheck %s --check-prefix=%{BACKEND}-STATS-CHECK + +// clang-format off +// +// Logical matmul: [128, 64] @ [64, 256] -> [128, 256] +// Tile sizes: M1=8, N1=8, K1=4 +// LHS logical [M0=16, K0=16, M1=8, K1=4], physical same (contiguous) +// RHS logical [K0=16, N0=32, K1=4, N1=8], physical [N0=32, K0=16, N1=8, K1=4] (transposed) +// OUT [M0=16, N0=32, M1=8, N1=8] +// +// TORCH-CHECK: module @module { +// TORCH-CHECK: func.func @main(%{{.+}}: !torch.tensor<[16,32,8,8],f32>, %{{.+}}: !torch.vtensor<[16,16,8,4],f32>, %{{.+}}: !torch.vtensor<[32,16,8,4],f32>) +// TORCH-CHECK: torch_c.to_builtin_tensor +// TORCH-CHECK: torch_c.to_builtin_tensor +// TORCH-CHECK: linalg.mmt4d +// TORCH-CHECK: torch_c.from_builtin_tensor +// TORCH-CHECK: torch.overwrite.tensor.contents +// +// FLOW-CHECK: linalg.mmt4d +// +// AMDGPU-STATS-CHECK: "dispatch-count": 1 +// CPU-STATS-CHECK: "dispatch-count": 1 +// +// clang-format on + +#include + +#include "utils.h" + +#include +#include +#include +#include + +using namespace fusilli; + +static ErrorObject testBlockedMatmulAsmEmitter(const std::string &mode) { + // Logical matmul: [128, 64] @ [64, 256] -> [128, 256] + // Tile sizes: m1=8, n1=8, k1=4 + int64_t m0 = 16, k0 = 16, m1 = 8, k1 = 4; + int64_t n0 = 32, n1 = 8; + + auto graph = std::make_shared(); + graph->setName("blocked_matmul_mmt4d"); + graph->setIODataType(DataType::Float).setComputeDataType(DataType::Float); + + // LHS: logical [m0, k0, m1, k1], contiguous (row-major) + auto lhsT = graph->tensor(TensorAttr() + .setName("arg0_lhs") 
+ .setDim({m0, k0, m1, k1}) + .setStride({k0 * m1 * k1, m1 * k1, k1, 1})); + + // RHS: logical [k0, n0, k1, n1], physical [n0, k0, n1, k1] (transposed) + // Strides encode the physical layout: dim order in memory is [n0, k0, n1, k1] + // stride[0] (k0) = n1 * k1 (k0 moves within an n0-block) + // stride[1] (n0) = k0 * n1 * k1 (n0 is outermost) + // stride[2] (k1) = 1 (k1 is innermost) + // stride[3] (n1) = k1 (n1 comes before k1 innermost) + auto rhsT = graph->tensor(TensorAttr() + .setName("arg1_rhs") + .setDim({k0, n0, k1, n1}) + .setStride({n1 * k1, k0 * n1 * k1, 1, k1})); + + auto bmAttr = BlockedMatmulAttr().setName("blocked_matmul"); + + auto outT = graph->blockedMatmul(lhsT, rhsT, bmAttr); + outT->setName("result").setOutput(true); + + FUSILLI_CHECK_ERROR(graph->validate()); + + if (mode == "default") { + FUSILLI_ASSIGN_OR_RETURN(auto generatedAsm, graph->emitAsm()); + FUSILLI_CHECK_ERROR(checkMlirIndentation(generatedAsm)); + std::cout << generatedAsm << std::endl; + } + + if (mode == "stats") { + FUSILLI_ASSIGN_OR_RETURN(Handle handle, Handle::create(kDefaultBackend)); + FUSILLI_CHECK_ERROR(graph->compile(handle, /*remove=*/true)); + FUSILLI_ASSIGN_OR_RETURN(auto stats, graph->readCompilationCacheFile( + CachedAssetsType::Statistics)); + std::cout << stats << std::endl; + } + + return ok(); +} + +int main(int argc, char **argv) { + std::string mode = (argc > 1) ? argv[1] : "default"; + + auto status = testBlockedMatmulAsmEmitter(mode); + if (isError(status)) { + std::cerr << "Test failed: " << status << std::endl; + return 1; + } + return 0; +}