Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,37 +50,23 @@ static void debugLog(StringRef opName, ArrayRef<const LevelLattice*> operands,
});
};

LevelState transferForward(mgmt::ModReduceOp op,
ArrayRef<const LevelLattice*> operands) {
LevelState result = std::visit(
Overloaded{
[](MaxLevel) -> LevelState { return LevelState(Invalid{}); },
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
[](int val) -> LevelState { return LevelState(val + 1); },
},
operands[0]->getValue().get());
LLVM_DEBUG(debugLog("mod_reduce", operands, result));
return result;
}

LevelState transferForward(mgmt::LevelReduceOp op,
LevelState transferForward(ReducesLevelOpInterface op,
ArrayRef<const LevelLattice*> operands) {
LevelState result = std::visit(
Overloaded{
[](MaxLevel) -> LevelState { return LevelState(Invalid{}); },
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
[&](int val) -> LevelState {
return LevelState(val + (int)op.getLevelToDrop());
return LevelState(val + op.getLevelsToDrop());
},
},
operands[0]->getValue().get());
LLVM_DEBUG(debugLog("level_reduce", operands, result));
LLVM_DEBUG(debugLog("ReduceLevelOpInterface", operands, result));
return result;
}

LevelState transferForward(mgmt::LevelReduceMinOp op,
LevelState transferForward(ReducesAllLevelsOpInterface op,
ArrayRef<const LevelLattice*> operands) {
LevelState result = std::visit(
Overloaded{
Expand All @@ -92,11 +78,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op,
[](int val) -> LevelState { return LevelState(MaxLevel{}); },
},
operands[0]->getValue().get());
LLVM_DEBUG(debugLog("level_reduce_min", operands, result));
LLVM_DEBUG(debugLog("ReduceAllLevelsOpInterface", operands, result));
return result;
}

LevelState transferForward(mgmt::BootstrapOp op,
LevelState transferForward(ResetsLevelOpInterface op,
ArrayRef<const LevelLattice*> operands) {
LevelState result = std::visit(
Overloaded{
Expand All @@ -106,15 +92,18 @@ LevelState transferForward(mgmt::BootstrapOp op,
[](int val) -> LevelState { return LevelState(0); },
},
operands[0]->getValue().get());
LLVM_DEBUG(debugLog("bootstrap", operands, result));
LLVM_DEBUG(debugLog("ResetsLevelOpInterface", operands, result));
return result;
}

LevelState deriveResultLevel(Operation* op,
ArrayRef<const LevelLattice*> operands) {
return llvm::TypeSwitch<Operation&, LevelState>(*op)
.Case<mgmt::ModReduceOp, mgmt::LevelReduceOp, mgmt::BootstrapOp,
mgmt::LevelReduceMinOp>(
.Case<ResetsLevelOpInterface>(
[&](auto op) -> LevelState { return transferForward(op, operands); })
.Case<ReducesAllLevelsOpInterface>(
[&](auto op) -> LevelState { return transferForward(op, operands); })
.Case<ReducesLevelOpInterface>(
[&](auto op) -> LevelState { return transferForward(op, operands); })
.Default([&](auto& op) -> LevelState {
LevelState result;
Expand Down
35 changes: 35 additions & 0 deletions lib/Dialect/HEIRInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,41 @@ def ResetsMulDepthOpInterface : OpInterface<"ResetsMulDepthOpInterface"> {
}];
}

def ResetsLevelOpInterface : OpInterface<"ResetsLevelOpInterface"> {
let cppNamespace = "::mlir::heir";
let description = [{
An interface that signals when an operation resets level
among its results, such as a `mgmt.bootstrap`.
}];
}

def ReducesLevelOpInterface : OpInterface<"ReducesLevelOpInterface"> {
let cppNamespace = "::mlir::heir";
let description = [{
An interface that signals when an operation reduces level
among its results, such as a `mgmt.mod_reduce` or `ckks.rescale`.
}];

let methods = [
InterfaceMethod<
/*desc=*/"Return the number of levels to reduce by.",
/*retTy=*/"int",
/*methodName=*/"getLevelsToDrop",
/*args=*/(ins ),
/*body=*/[{}],
/*defaultBody=*/[{ return 1; }]
>,
];
}

def ReducesAllLevelsOpInterface : OpInterface<"ReducesAllLevelsOpInterface"> {
let cppNamespace = "::mlir::heir";
let description = [{
An interface that signals when an operation reduces all level
among its results, such as a `mgmt.level_reduce_min`.
}];
}

def LUTOpInterface : OpInterface<"LUTOpInterface"> {
let cppNamespace = "::mlir::heir";
let description = [{
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ def Lattigo_BGVMulOp : Lattigo_BGVBinaryInPlaceOp<"mul", [IncreasesMulDepthOpInt
}];
}

class Lattigo_BGVUnaryOp<string mnemonic> :
Lattigo_BGVOp<mnemonic> {
class Lattigo_BGVUnaryOp<string mnemonic, list<Trait> traits = []> :
Lattigo_BGVOp<mnemonic, traits> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input
Expand All @@ -198,7 +198,7 @@ def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> {
}];
}

def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> {
def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new", [ReducesLevelOpInterface]> {
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo BGV dialect.
Expand Down Expand Up @@ -258,7 +258,7 @@ def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInPlaceOp<"relinearize"> {
}];
}

def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale"> {
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> {
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo BGV dialect.
Expand Down
10 changes: 5 additions & 5 deletions lib/Dialect/Lattigo/IR/LattigoCKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def Lattigo_CKKSMulOp : Lattigo_CKKSBinaryInPlaceOp<"mul", [IncreasesMulDepthOpI
}];
}

class Lattigo_CKKSUnaryOp<string mnemonic> :
Lattigo_CKKSOp<mnemonic> {
class Lattigo_CKKSUnaryOp<string mnemonic, list<Trait> traits = []> :
Lattigo_CKKSOp<mnemonic, traits> {
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input
Expand All @@ -231,7 +231,7 @@ def Lattigo_CKKSRelinearizeNewOp : Lattigo_CKKSUnaryOp<"relinearize_new"> {
}];
}

def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new"> {
def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new", [ReducesLevelOpInterface]> {
let summary = "Rescale a ciphertext in the Lattigo CKKS dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo CKKS dialect.
Expand Down Expand Up @@ -284,7 +284,7 @@ def Lattigo_CKKSRelinearizeOp : Lattigo_CKKSUnaryInPlaceOp<"relinearize"> {
}];
}

def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale"> {
def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> {
let summary = "Rescale a ciphertext in the Lattigo CKKS dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo CKKS dialect.
Expand Down Expand Up @@ -322,7 +322,7 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [
let hasVerifier = 1;
}

def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> {
def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap", [ResetsLevelOpInterface]> {
let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect";
let description = [{
Bootstraps a ciphertext value in the Lattigo CKKS dialect.
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ LogicalResult RLWENewEncryptorOp::verify() {
return success();
}

int RLWEDropLevelNewOp::getLevelsToDrop() { return getLevelToDrop(); }

int RLWEDropLevelOp::getLevelsToDrop() { return getLevelToDrop(); }

LogicalResult BGVRotateColumnsNewOp::verify() {
return containsExactlyOneOrEmitError(getOperation(), getDynamicShift(),
getStaticShift());
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/Lattigo/IR/LattigoRLWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def Lattigo_RLWEDecryptOp : Lattigo_RLWEOp<"decrypt"> {
let results = (outs Lattigo_RLWEPlaintext:$plaintext);
}

def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> {
def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new",
[DeclareOpInterfaceMethods<ReducesLevelOpInterface, ["getLevelsToDrop"]>]> {
let summary = "Drop level of a ciphertext";
let arguments = (ins
Lattigo_RLWEEvaluator:$evaluator,
Expand All @@ -132,7 +133,8 @@ def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> {
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InPlaceOpInterface]> {
def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level",
[InPlaceOpInterface, DeclareOpInterfaceMethods<ReducesLevelOpInterface, ["getLevelsToDrop"]>]> {
let summary = "Drop level of a ciphertext";
let description = [{
This operation drops the level of a ciphertext
Expand Down
Loading
Loading