@@ -17,10 +17,14 @@
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

@@ -98,11 +102,40 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
// of the ciphertext at that point in the computation, as well as the decision
// variable to track whether to insert a relinearization operation after the
// operation.
opToRunOn->walk([&](Operation* op) {
opToRunOn->walk<WalkOrder::PreOrder>([&](Operation* op) -> WalkResult {
// Skip inner loop bodies: they are handled by their own ILP solves in
// bottom-up order. But we still need to create variables for the inner
// loop's results so the outer solver knows about the inner loop.
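// With WalkOrder::PreOrder, returning WalkResult::skip() prevents the
// walk from recursing into this op's regions, while
// WalkResult::advance() continues the traversal normally.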
if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
std::string name = uniqueName(op);
if (isSecret(op->getResults(), solver)) {
auto decisionVar = model.AddBinaryVariable("InsertRelin_" + name);
decisionVariables.insert(std::make_pair(op, decisionVar));
}

for (OpResult opResult : op->getOpResults()) {
Value result = opResult;
if (!isSecret(result, solver)) {
continue;
}
std::string varName =
"Degree_" + name + "_" + std::to_string(opResult.getResultNumber());
auto keyBasisVar =
model.AddContinuousVariable(0, MAX_KEY_BASIS_DEGREE, varName);
keyBasisVars.insert(std::make_pair(result, keyBasisVar));

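// The "_br" ("before relin") variable tracks the key basis degree of
// this result prior to any relinearization; the DecisionDynamics
// constraints added later tie it to the post-relin variable.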
std::string brVarName = varName + "_br";
auto brKeyBasisVar =
model.AddContinuousVariable(0, MAX_KEY_BASIS_DEGREE, brVarName);
beforeRelinVars.insert(std::make_pair(result, brKeyBasisVar));
}
return WalkResult::skip();
}

std::string name = uniqueName(op);

if (isa<ModuleOp>(op)) {
return;
return WalkResult::advance();
}

// skip secret generic op; we decide inside generic op block
@@ -138,16 +171,30 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
// linearized, though this could be generalized to read the degree from the
// type.
if (op->getNumRegions() == 0) {
return;
return WalkResult::advance();
}

LLVM_DEBUG(llvm::dbgs()
<< "Handling block arguments for " << op->getName() << "\n");
for (Region& region : op->getRegions()) {
for (Block& block : region.getBlocks()) {
for (BlockArgument arg : block.getArguments()) {
if (!isSecret(arg, solver)) {
continue;
bool argIsSecret = isSecret(arg, solver);

// handle iter_args that become secret via yield
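// (e.g. an iter_arg seeded with a plaintext initial value but
// overwritten by a secret yielded value becomes secret after the
// first iteration).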
if (!argIsSecret) {
if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op)) {
auto iterArgs = loopOp.getRegionIterArgs();
auto it = llvm::find(iterArgs, arg);
if (it != iterArgs.end()) {
unsigned idx = std::distance(iterArgs.begin(), it);
auto yieldedValues = loopOp.getYieldedValues();
if (idx < yieldedValues.size()) {
argIsSecret = isSecret(yieldedValues[idx], solver);
}
}
}
if (!argIsSecret) continue;
}

std::stringstream ss;
@@ -159,14 +206,22 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}
}
}
return WalkResult::advance();
});

// Constraints to initialize the key basis degree variables at the start of
// the computation.
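// For example, a fresh ciphertext of dimension 2 has key basis [0, 1]
// and hence degree 1.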
for (auto& [value, var] : keyBasisVars) {
if (llvm::isa<BlockArgument>(value)) {
// If the dimension is 3, the key basis is [0, 1, 2] and the degree is 2.
auto constrainedDegree = getDimension(value, solver).value_or(2) - 1;
auto blockArg = llvm::cast<BlockArgument>(value);
int constrainedDegree;
// Loop iter_args are always assumed to have degree 1 since getDimension diverges
if (isa<LoopLikeOpInterface>(blockArg.getOwner()->getParentOp())) {
constrainedDegree = 1;
} else {
// If the dimension is 3, the key basis is [0, 1, 2] and the degree is 2.
constrainedDegree = getDimension(value, solver).value_or(2) - 1;
}
model.AddLinearConstraint(var == constrainedDegree, "");
}
}
@@ -179,15 +234,20 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
// through from the input unchanged. If we don't require this, the output
// of the addition must be a max over the input degrees.
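// Example: for %z = arith.addi %x, %y with both operands secret, the
// constraints below force deg(%x) == deg(%y), so the add cannot mix a
// quadratic ciphertext with a linear one.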
if (!allowMixedDegreeOperands) {
opToRunOn->walk([&](Operation* op) {
opToRunOn->walk<WalkOrder::PreOrder>([&](Operation* op) -> WalkResult {
// Skip loop bodies — they will be handled by a recursive solver
if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
return WalkResult::skip();
}

if (op->getNumOperands() <= 1) {
return;
return WalkResult::advance();
}

// Secret generic op arguments are not constrained;
// instead their block arguments are constrained.
if (isa<secret::GenericOp>(op)) {
return;
return WalkResult::advance();
}

std::string name = uniqueName(op);
@@ -196,7 +256,7 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
SmallVector<OpOperand*, 4> secretOperands;
getSecretOperands(op, secretOperands, solver);
if (secretOperands.size() <= 1) {
return;
return WalkResult::advance();
}

auto anchorVar = keyBasisVars.at(secretOperands[0]->get());
@@ -215,13 +275,17 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
<< name;
model.AddLinearConstraint(operandDegreeVar == anchorVar, ss.str());
}
return WalkResult::advance();
});
}

// Some ops require a linear key basis. Yield is a special case
// where we require returned values from funcs to be linearized.
// TODO(#1398): determine whether we need linear key basis for modreduce.
opToRunOn->walk([&](Operation* op) {
opToRunOn->walk<WalkOrder::PreOrder>([&](Operation* op) -> WalkResult {
if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
return WalkResult::skip();
}
llvm::TypeSwitch<Operation&>(*op)
.Case<tensor_ext::RotateOp, secret::YieldOp, mgmt::ModReduceOp>(
[&](auto op) {
@@ -244,7 +308,35 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
<< operand.getOperandNumber();
model.AddLinearConstraint(operandDegreeVar == 1, ss.str());
}
});
})
.Case<affine::AffineYieldOp, scf::YieldOp>([&](auto op) {
// For loop yield ops, the degree returned must not exceed the degree
// of the corresponding iter_arg block argument at the start of the
// loop. This prevents unbounded growth across loop iterations.

[Review comment (Collaborator)] I actually think the degrees should be
equal. In particular, when this is lowered to the CKKS scheme, having
degrees that are not equal across iter_args will produce a type error.
I think it would be a good and simplifying assumption to enforce by fiat
that all iter_args have a linear key basis.
auto parentLoop = op->getParentOp();
auto loopLike = dyn_cast<LoopLikeOpInterface>(parentLoop);
if (!loopLike) return;

auto iterArgs = loopLike.getRegionIterArgs();

// Number of iter args should match number of yielded operands
if (iterArgs.size() != op.getNumOperands()) return;

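// Example: if a mul in the loop body raises a yielded value's degree
// to 2 while its iter_arg is pinned at degree 1, this constraint
// forces a relinearization before the yield.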
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto yieldOperand = op.getOperand(i);
if (!isSecret(yieldOperand, solver)) continue;
if (!keyBasisVars.contains(yieldOperand)) continue;

auto yieldDegreeVar = keyBasisVars.at(yieldOperand);
auto iterArgDegreeVar = keyBasisVars.at(iterArgs[i]);

model.AddLinearConstraint(
yieldDegreeVar <= iterArgDegreeVar,
"LoopCarriedDependency_" + std::to_string(i) + "_" +
uniqueName(op));
}
});
return WalkResult::advance();
});

// When mixed-degree ops are enabled, the default result degree of an op is
@@ -254,7 +346,27 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
std::unordered_set<const math_opt::Variable*> extraVarsForObjective;

// Add constraints that set the before_relin variables appropriately
opToRunOn->walk([&](Operation* op) {
opToRunOn->walk<WalkOrder::PreOrder>([&](Operation* op) -> WalkResult {
// For nested inner loops, apply the LoopOutputDegree constraint using the
// degree solved by the inner loop's ILP, then skip the loop body.
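// (Loops are processed bottom-up, so the inner loop's ILP has already
// run and recorded its result degrees in loopBoundaryDegrees.)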
if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
for (auto [idx, result] : llvm::enumerate(op->getResults())) {
if (!isSecret(result, solver)) continue;
auto resultBeforeRelinVar = beforeRelinVars.at(result);
// Use the degree solved by the inner loop's ILP.
// Default to 1 if not yet populated (loop with no secret ops).
int solvedDegree = 1;
auto it = loopBoundaryDegrees.find(op);
if (it != loopBoundaryDegrees.end() && idx < it->second.size()) {
solvedDegree = it->second[idx];
}
model.AddLinearConstraint(
resultBeforeRelinVar == solvedDegree,
"LoopOutputDegree_" + uniqueName(op) + "_" + std::to_string(idx));
}
return WalkResult::skip();
}

llvm::TypeSwitch<Operation&>(*op)
.Case<arith::MulIOp, arith::MulFOp>([&](auto op) {
// if plain mul, skip
@@ -299,7 +411,8 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}
})
.Default([&](Operation& op) {
if (isa<LoopLikeOpInterface>(op)) return;

// For any other op, the key basis does not change unless we insert
// a relin op. The operands may have the same basis degree, if that
// is required by the backend and allowMixedDegreeOperands is false,
// in which case we can just forward the degree of the first secret
@@ -360,7 +473,8 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
}
}
});
});
return WalkResult::advance();
});

// The objective is to minimize the number of relinearization ops.
// TODO(#1018): improve the objective function to account for differing costs
@@ -373,7 +487,48 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
model.Minimize(obj);

// Add constraints that control the effect of relinearization insertion.
opToRunOn->walk([&](Operation* op) {
opToRunOn->walk<WalkOrder::PreOrder>([&](Operation* op) -> WalkResult {
// Helper to add DecisionDynamics constraints for an op's results.
auto addDecisionDynamics = [&](Operation* targetOp) {
if (!isSecret(targetOp->getResults(), solver)) return;
for (OpResult opResult : targetOp->getResults()) {
Value result = opResult;
if (!isSecret(result, solver)) continue;

auto resultBeforeRelinVar = beforeRelinVars.at(result);
auto resultAfterRelinVar = keyBasisVars.at(result);
auto insertRelinOpDecision = decisionVariables.at(targetOp);

std::string opName = uniqueName(targetOp);
std::string ddPrefix = "DecisionDynamics_" + opName + "_" +
std::to_string(opResult.getResultNumber());

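// Big-M ("if-then") encoding using the auxiliary constant IF_THEN_AUX:
// when insertRelinOpDecision == 1, constraints _1 and _2 pin the
// post-relin degree to exactly 1 (a linear key basis); when it is 0,
// constraints _3 and _4 force resultAfterRelinVar == resultBeforeRelinVar.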
model.AddLinearConstraint(resultAfterRelinVar >= insertRelinOpDecision,
ddPrefix + "_1");

model.AddLinearConstraint(
resultAfterRelinVar <= 1 + IF_THEN_AUX * (1 - insertRelinOpDecision),
ddPrefix + "_2");

model.AddLinearConstraint(
resultAfterRelinVar >=
resultBeforeRelinVar - IF_THEN_AUX * insertRelinOpDecision,
ddPrefix + "_3");

model.AddLinearConstraint(
resultAfterRelinVar <=
resultBeforeRelinVar + IF_THEN_AUX * insertRelinOpDecision,
ddPrefix + "_4");
}
};

// For nested inner loops, apply the DecisionDynamics constraints to the
// loop's results, then skip the loop body.
if (isa<LoopLikeOpInterface>(op) && op != opToRunOn) {
addDecisionDynamics(op);
return WalkResult::skip();
}

// We don't need a type switch here because the only difference
// between mul and other ops is how the before_relin variable is related to
// the operand variables.
@@ -386,42 +541,11 @@ LogicalResult OptimizeRelinearizationAnalysis::solve() {
// Secret generic op arguments are not constrained;
// instead their block arguments are constrained.
if (isa<secret::GenericOp>(op)) {
return;
}
if (!isSecret(op->getResults(), solver)) {
return;
return WalkResult::advance();
}

for (OpResult opResult : op->getResults()) {
Value result = opResult;
auto resultBeforeRelinVar = beforeRelinVars.at(result);
auto resultAfterRelinVar = keyBasisVars.at(result);
auto insertRelinOpDecision = decisionVariables.at(op);
std::string opName = uniqueName(op);
std::string ddPrefix = "DecisionDynamics_" + opName + "_" +
std::to_string(opResult.getResultNumber());

cstName = ddPrefix + "_1";
model.AddLinearConstraint(resultAfterRelinVar >= insertRelinOpDecision,
cstName);

cstName = ddPrefix + "_2";
model.AddLinearConstraint(
resultAfterRelinVar <= 1 + IF_THEN_AUX * (1 - insertRelinOpDecision),
cstName);

cstName = ddPrefix + "_3";
model.AddLinearConstraint(
resultAfterRelinVar >=
resultBeforeRelinVar - IF_THEN_AUX * insertRelinOpDecision,
cstName);

cstName = ddPrefix + "_4";
model.AddLinearConstraint(
resultAfterRelinVar <=
resultBeforeRelinVar + IF_THEN_AUX * insertRelinOpDecision,
cstName);
}
addDecisionDynamics(op);
return WalkResult::advance();
});

// Dump the model
@@ -38,6 +38,10 @@ class OptimizeRelinearizationAnalysis {
return solutionKeyBasisDegreeBeforeRelin.lookup(value);
}

/// Maps a loop operation to its output degrees (one int per loop result).
/// Populated by the inner solver and read by the outer solver.
llvm::DenseMap<Operation*, SmallVector<int>> loopBoundaryDegrees;
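/// For example, an scf.for loop with one secret result that the inner
/// solve leaves with a linear key basis would map to {loopOp -> [1]}.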

private:
Operation* opToRunOn;
DataFlowSolver* solver;