Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions lib/Dialect/Rotom/IR/RotomAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

Expand Down Expand Up @@ -102,48 +103,48 @@ static FailureOr<LayoutData> preprocessLayoutData(ArrayAttr dims, int64_t n,
data.pieceIndex.insert(data.pieceIndex.begin() + data.ctPrefixLen, gapIdx);
}

llvm::DenseSet<int64_t> seenDim;
bool allUnique = true;
for (const DimAttr& d : data.traversalDims) {
if (seenDim.contains(d.getDim())) {
allUnique = false;
break;
}
seenDim.insert(d.getDim());
return data;
}

} // namespace

static LogicalResult verifyLayoutRolls(
ArrayAttr dims, DenseI64ArrayAttr rolls,
function_ref<InFlightDiagnostic()> emitError) {
if (!rolls) return success();
ArrayRef<int64_t> r = rolls.asArrayRef();
if (r.empty()) return success();
if (r.size() % 2 != 0) {
return emitError() << "rolls must contain an even number of integers "
"(pairs of dim indices)";
}
if (allUnique && data.traversalDims.size() > 1) {
llvm::SmallVector<std::pair<int64_t, int64_t>> byDim;
byDim.reserve(data.traversalDims.size());
for (int64_t i = 0; i < static_cast<int64_t>(data.traversalDims.size());
++i) {
byDim.push_back({data.traversalDims[i].getDim(), i});

for (size_t i = 0; i < r.size(); i += 2) {
const int64_t ti = r[i];
const int64_t tj = r[i + 1];
if (ti == tj) {
return emitError() << "each roll must use two distinct dim indices";
}
llvm::sort(byDim,
[](const auto& a, const auto& b) { return a.first < b.first; });

llvm::SmallVector<DimAttr> reorderedTraversal;
reorderedTraversal.reserve(data.traversalDims.size());
llvm::SmallVector<int64_t> oldToNew(data.traversalDims.size(), 0);
for (int64_t newIdx = 0; newIdx < static_cast<int64_t>(byDim.size());
++newIdx) {
const int64_t oldIdx = byDim[newIdx].second;
oldToNew[oldIdx] = newIdx;
reorderedTraversal.push_back(data.traversalDims[oldIdx]);
if (ti < 0 || tj < 0 || ti >= static_cast<int64_t>(dims.size()) ||
tj >= static_cast<int64_t>(dims.size())) {
return emitError() << "roll dim index out of bounds for dims list";
}
data.traversalDims = std::move(reorderedTraversal);

for (size_t p = 0; p < data.pieces.size(); ++p) {
if (data.pieces[p] == LayoutPieceKind::Traversal) {
data.pieceIndex[p] = oldToNew[data.pieceIndex[p]];
}
auto di = dyn_cast<DimAttr>(dims[ti]);
auto dj = dyn_cast<DimAttr>(dims[tj]);
if (!di || !dj) {
return emitError() << "roll indices must refer to #rotom.dim entries";
}
if (di.getSize() != dj.getSize()) {
return emitError() << "rolled dims must have the same extent (size)";
}
if (di.isGap() || dj.isGap() || di.isReplicate() || dj.isReplicate()) {
return emitError() << "rolls may only reference non-sentinel traversal "
"dims (dim >= 0)";
}
}

return data;
return success();
}

} // namespace

LogicalResult DimAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t dim, int64_t size, int64_t stride) {
if (dim < -2) {
Expand All @@ -164,7 +165,8 @@ FailureOr<LayoutData> preprocessLayoutAttr(LayoutAttr layout) {
}

LogicalResult LayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayAttr dims, int64_t n) {
ArrayAttr dims, int64_t n,
DenseI64ArrayAttr rolls) {
if (n <= 0) {
return emitError() << "`n` must be > 0, got " << n;
}
Expand All @@ -173,6 +175,8 @@ LogicalResult LayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "`dims` must be an array of `#rotom.dim<...>`";
}

if (failed(verifyLayoutRolls(dims, rolls, emitError))) return failure();

MLIRContext* ctx = dims.getContext();
std::vector<DimAttr> ctDims;
std::vector<DimAttr> slotDims;
Expand Down Expand Up @@ -253,6 +257,11 @@ LogicalResult LayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

LayoutAttr LayoutAttr::get(MLIRContext* context, ArrayAttr dims, int64_t n) {
return get(context, dims, n,
DenseI64ArrayAttr::get(context, ArrayRef<int64_t>{}));
}

} // namespace rotom
} // namespace heir
} // namespace mlir
14 changes: 13 additions & 1 deletion lib/Dialect/Rotom/IR/RotomAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,26 @@ def Rotom_LayoutAttr : Rotom_Attr<"Layout", "layout"> {
For tensor_ext materialization, the **first** entry in `dims` is the
ciphertext side of Rotom's `;` split (one piece); remaining entries are
in-slot. See [Section 4.2 of the Rotom paper](https://eprint.iacr.org/2025/1319.pdf).

Optional **rolls** encode a `roll(i,j)` metadata object: each pair `(i, j)`
indexes into the `dims` array (the flattened `ct_dims + slot_dims` list) and
uses modular addition to modify the indices of `dims[i]` by the indices of
`dims[j]`.
}];

let parameters = (ins
"::mlir::ArrayAttr":$dims,
"int64_t":$n
"int64_t":$n,
OptionalParameter<"::mlir::DenseI64ArrayAttr">:$rolls
);

let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Layout with no `roll(i,j)` metadata (empty rolls storage).
static ::mlir::heir::rotom::LayoutAttr get(::mlir::MLIRContext *context,
::mlir::ArrayAttr dims, int64_t n);
}];
}

#endif // LIB_DIALECT_ROTOM_IR_ROTOMATTRIBUTES_TD_
125 changes: 97 additions & 28 deletions lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>

#include "lib/Dialect/Rotom/IR/RotomAttributes.h"
Expand All @@ -10,19 +11,43 @@
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

namespace mlir::heir::rotom {
namespace {

/// Maps a `#rotom.dim` from the layout's `dims` list to its iterator index `i*`
/// after preprocessing (match logical axis, size, and stride).
static FailureOr<int64_t> traversalIndexForRotomDim(
const SmallVector<DimAttr>& traversalDims, DimAttr want) {
for (int64_t i = 0; i < static_cast<int64_t>(traversalDims.size()); ++i) {
if (traversalDims[i].getDim() == want.getDim() &&
traversalDims[i].getSize() == want.getSize() &&
traversalDims[i].getStride() == want.getStride()) {
return i;
}
}
return failure();
}

static std::string modExpr(llvm::StringRef expr, int64_t mod) {
std::string out;
llvm::raw_string_ostream os(out);
os << "(" << expr << " - " << mod << " * floor((" << expr << ") / " << mod
<< "))";
return out;
}

static LogicalResult emitSegmentAddress(
llvm::raw_ostream& os, bool& firstTerm, ArrayRef<LayoutPieceKind> pieces,
ArrayRef<int64_t> pieceIndex, const SmallVector<DimAttr>& traversalDims,
const SmallVector<DimAttr>& gapDims,
const SmallVector<DimAttr>& replicationDims,
int64_t numActiveTraversalComponents, size_t segStart, size_t segEnd,
bool foldGapVarsToZero) {
bool foldGapVarsToZero, ArrayRef<int64_t> rolls, ArrayAttr rotomDims,
bool isSlotLine) {
llvm::SmallVector<int64_t> suffixCoeff(pieces.size(), 0);
int64_t suffix = 1;
for (size_t p = segEnd; p > segStart;) {
Expand All @@ -40,39 +65,77 @@ static LogicalResult emitSegmentAddress(
d = replicationDims[pieceIndex[p]];
break;
}

suffix *= d.getSize();
}

auto emitTerm = [&](int64_t coeff, llvm::StringRef var) -> LogicalResult {
auto emitTerm = [&](int64_t coeff, llvm::StringRef expr) -> LogicalResult {
if (coeff == 0) return failure();
if (!firstTerm) os << " + ";
firstTerm = false;
if (coeff == 1) {
os << var;
if (firstTerm) {
if (coeff < 0) os << "-";
firstTerm = false;
} else {
os << coeff << " * " << var;
os << (coeff < 0 ? " - " : " + ");
}
const int64_t absCoeff = std::llabs(coeff);
if (absCoeff == 1) {
os << expr;
} else {
os << absCoeff << " * " << expr;
}
return success();
};

llvm::DenseMap<int64_t, int64_t> traversalCoeff;
llvm::DenseMap<int64_t, int64_t> gapCoeff;
llvm::DenseMap<int64_t, int64_t> replicationCoeff;
for (size_t p = segStart; p < segEnd; ++p) {
const int64_t coeff = suffixCoeff[p];

if (pieces[p] == LayoutPieceKind::Traversal) {
const int64_t ti = pieceIndex[p];
if (traversalDims[ti].getSize() == 1) continue;
traversalCoeff[ti] = coeff;
} else if (pieces[p] == LayoutPieceKind::Gap) {
if (pieces[p] == LayoutPieceKind::Gap) {
if (foldGapVarsToZero) continue;
const int64_t gk = pieceIndex[p];
gapCoeff[gk] = coeff;
} else {
const int64_t ek = pieceIndex[p];
replicationCoeff[ek] = coeff;
gapCoeff[pieceIndex[p]] = coeff;
} else if (pieces[p] == LayoutPieceKind::Replication) {
replicationCoeff[pieceIndex[p]] = coeff;
}
}

llvm::DenseMap<int64_t, int64_t> traversalCoeff;
for (size_t p = segStart; p < segEnd; ++p) {
if (pieces[p] != LayoutPieceKind::Traversal) continue;
const int64_t ti = pieceIndex[p];
if (traversalDims[ti].getSize() == 1) continue;
traversalCoeff[ti] = suffixCoeff[p];
}

llvm::SmallVector<std::string> traversalExprs;
traversalExprs.reserve(traversalDims.size());
for (int64_t i = 0; i < static_cast<int64_t>(traversalDims.size()); ++i) {
traversalExprs.push_back("i" + std::to_string(i));
}

// Apply roll(a,b) transforms left-to-right:
// t_a <- (t_a - t_b) mod extent(a).
if (isSlotLine && !rolls.empty()) {
if (!rotomDims || rolls.size() % 2 != 0) return failure();
for (size_t i = 0; i < rolls.size(); i += 2) {
const int64_t fromIdx = rolls[i];
const int64_t toIdx = rolls[i + 1];
if (fromIdx < 0 || toIdx < 0 ||
fromIdx >= static_cast<int64_t>(rotomDims.size()) ||
toIdx >= static_cast<int64_t>(rotomDims.size())) {
return failure();
}
auto fromDim = dyn_cast<DimAttr>(rotomDims[fromIdx]);
auto toDim = dyn_cast<DimAttr>(rotomDims[toIdx]);
if (!fromDim || !toDim) return failure();
FailureOr<int64_t> maybeFromTrav =
traversalIndexForRotomDim(traversalDims, fromDim);
FailureOr<int64_t> maybeToTrav =
traversalIndexForRotomDim(traversalDims, toDim);
if (failed(maybeFromTrav) || failed(maybeToTrav)) return failure();
const int64_t fromTrav = *maybeFromTrav;
const int64_t toTrav = *maybeToTrav;
std::string diffExpr =
"(" + traversalExprs[fromTrav] + " - " + traversalExprs[toTrav] + ")";
traversalExprs[fromTrav] = modExpr(diffExpr, fromDim.getSize());
}
}

Expand All @@ -81,7 +144,7 @@ static LogicalResult emitSegmentAddress(
if (traversalDims[oldIdx].getSize() == 1) continue;
auto it = traversalCoeff.find(oldIdx);
if (it != traversalCoeff.end()) {
if (failed(emitTerm(it->second, "i" + std::to_string(oldIdx))))
if (failed(emitTerm(it->second, traversalExprs[oldIdx])))
return failure();
}
}
Expand All @@ -107,7 +170,8 @@ static FailureOr<std::string> emitSplitCtSlotIsl(
ArrayRef<int64_t> pieceIndex, const SmallVector<DimAttr>& traversalDims,
const SmallVector<DimAttr>& replicationDims,
const SmallVector<DimAttr>& gapDims, int64_t numTraversalComponents,
int64_t numReplication, int64_t numGap) {
int64_t numReplication, int64_t numGap, ArrayRef<int64_t> rolls,
ArrayAttr rotomDims) {
if (prefix > pieces.size()) return failure();

int64_t numCt = 1;
Expand Down Expand Up @@ -170,7 +234,8 @@ static FailureOr<std::string> emitSplitCtSlotIsl(
if (failed(emitSegmentAddress(os, firstTerm, pieces, pieceIndex,
traversalDims, gapDims, replicationDims,
numTraversalComponents, 0, prefix,
foldGapVarsToZero)))
foldGapVarsToZero, rolls, rotomDims,
/*isSlotLine=*/false)))
return failure();
if (firstTerm) os << "0";

Expand All @@ -180,7 +245,8 @@ static FailureOr<std::string> emitSplitCtSlotIsl(
if (failed(emitSegmentAddress(os, firstTerm, pieces, pieceIndex,
traversalDims, gapDims, replicationDims,
numTraversalComponents, prefix, pieces.size(),
foldGapVarsToZero)))
foldGapVarsToZero, rolls, rotomDims,
/*isSlotLine=*/true)))
return failure();
if (firstTerm) os << "0";

Expand Down Expand Up @@ -215,10 +281,13 @@ static FailureOr<std::string> lowerToIslImpl(LayoutAttr layout) {
const int64_t numReplication =
static_cast<int64_t>(data.replicationDims.size());
const int64_t numGap = static_cast<int64_t>(data.gapDims.size());
return emitSplitCtSlotIsl(data.n, data.ctPrefixLen, data.pieces,
data.pieceIndex, data.traversalDims,
data.replicationDims, data.gapDims,
numTraversalComponents, numReplication, numGap);
DenseI64ArrayAttr rollsAttr = layout.getRolls();
ArrayRef<int64_t> rolls =
rollsAttr ? rollsAttr.asArrayRef() : ArrayRef<int64_t>{};
return emitSplitCtSlotIsl(
data.n, data.ctPrefixLen, data.pieces, data.pieceIndex,
data.traversalDims, data.replicationDims, data.gapDims,
numTraversalComponents, numReplication, numGap, rolls, layout.getDims());
}

} // namespace
Expand Down
Loading
Loading