Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 32 additions & 48 deletions include/fusilli/support/asm_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1853,73 +1853,57 @@ inline ErrorOr<std::string> ReductionNode::emitNodePreAsm() const {
std::string permuteY =
getLayoutConversionOpsAsm(yT, "permute_Y", suffix, /*isInput=*/false);

switch (reductionAttr.getMode()) {
case ReductionAttr::Mode::SUM: {
constexpr std::string_view schema = R"(
constexpr std::string_view kKeepdimReductionSchema = R"(
{0}
{1}
%keepdim_{2} = torch.constant.bool true
%dtype_{2} = torch.constant.none
{3}_{2}_perm = torch.aten.sum.dim_IntList {4}, %reduction_dims_{2}, %keepdim_{2}, %dtype_{2} : {5}, !torch.list<int>, !torch.bool, !torch.none -> {6}
{3}_{2}_perm = {8} {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list<int>, !torch.bool -> {6}
{7}
)";

return std::format(schema,
permuteX, // {0}
dimListOss.str(), // {1}
suffix, // {2}
getResultNamesAsm(), // {3}
getOperandNamesAsm(), // {4}
getOperandTypesAsm(), // {5}
getResultTypesAsm(), // {6}
permuteY // {7}
);
}
case ReductionAttr::Mode::MIN: {
constexpr std::string_view schema = R"(
constexpr std::string_view kKeepdimDtypeReductionSchema = R"(
{0}
{1}
%keepdim_{2} = torch.constant.bool true
{3}_{2}_perm = torch.aten.amin {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list<int>, !torch.bool -> {6}
%dtype_{2} = torch.constant.none
{3}_{2}_perm = {8} {4}, %reduction_dims_{2}, %keepdim_{2}, %dtype_{2} : {5}, !torch.list<int>, !torch.bool, !torch.none -> {6}
{7}
)";

return std::format(schema,
permuteX, // {0}
dimListOss.str(), // {1}
suffix, // {2}
getResultNamesAsm(), // {3}
getOperandNamesAsm(), // {4}
getOperandTypesAsm(), // {5}
getResultTypesAsm(), // {6}
permuteY // {7}
);
#define FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, SCHEMA, OPIR) \
case ReductionAttr::Mode::MODE: { \
return std::format(SCHEMA, permuteX, /* {0} */ \
dimListOss.str(), /* {1} */ \
suffix, /* {2} */ \
getResultNamesAsm(), /* {3} */ \
getOperandNamesAsm(), /* {4} */ \
getOperandTypesAsm(), /* {5} */ \
getResultTypesAsm(), /* {6} */ \
permuteY, /* {7} */ \
#OPIR /* {8} */ \
); \
}
case ReductionAttr::Mode::MAX: {
constexpr std::string_view schema = R"(
{0}
{1}
%keepdim_{2} = torch.constant.bool true
{3}_{2}_perm = torch.aten.amax {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list<int>, !torch.bool -> {6}
{7}
)";

return std::format(schema,
permuteX, // {0}
dimListOss.str(), // {1}
suffix, // {2}
getResultNamesAsm(), // {3}
getOperandNamesAsm(), // {4}
getOperandTypesAsm(), // {5}
getResultTypesAsm(), // {6}
permuteY // {7}
);
}
#define FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MODE, OPIR) \
FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, kKeepdimReductionSchema, OPIR)

#define FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER(MODE, OPIR) \
FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, kKeepdimDtypeReductionSchema, OPIR)

switch (reductionAttr.getMode()) {
FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER(SUM,
torch.aten.sum.dim_IntList)
FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MIN, torch.aten.amin)
FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MAX, torch.aten.amax)
default:
return error(ErrorCode::InternalError, "Unsupported reduction mode");
}
}

#undef FUSILLI_DECLARE_REDUCTION_EMITTER
#undef FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER
#undef FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER

//===----------------------------------------------------------------------===//
//
// CustomOpNode ASM Emitter Methods
Expand Down
8 changes: 6 additions & 2 deletions samples/reduction/reduction_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ TEST_CASE("Reduction ops", "[reduction][graph]") {
const auto xDims = std::vector<int64_t>{2, 16, 8, 8};
const auto yDims = std::vector<int64_t>{2, 16, 1, 1};

const auto mode = GENERATE(ReductionAttr::Mode::SUM, ReductionAttr::Mode::MIN,
ReductionAttr::Mode::MAX);
// clang-format off
const auto mode = GENERATE(
ReductionAttr::Mode::SUM,
ReductionAttr::Mode::MIN,
ReductionAttr::Mode::MAX);
// clang-format on

auto execute = [&]<typename T>(Handle &handle, DataType dt, T initValue) {
// Create graph.
Expand Down
Loading