diff --git a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp index b790d010d3..693a6165bc 100644 --- a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp +++ b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp @@ -50,21 +50,7 @@ static void debugLog(StringRef opName, ArrayRef operands, }); }; -LevelState transferForward(mgmt::ModReduceOp op, - ArrayRef operands) { - LevelState result = std::visit( - Overloaded{ - [](MaxLevel) -> LevelState { return LevelState(Invalid{}); }, - [](Uninit) -> LevelState { return LevelState(Invalid{}); }, - [](Invalid) -> LevelState { return LevelState(Invalid{}); }, - [](int val) -> LevelState { return LevelState(val + 1); }, - }, - operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("mod_reduce", operands, result)); - return result; -} - -LevelState transferForward(mgmt::LevelReduceOp op, +LevelState transferForward(ReducesLevelOpInterface op, ArrayRef operands) { LevelState result = std::visit( Overloaded{ @@ -72,15 +58,15 @@ LevelState transferForward(mgmt::LevelReduceOp op, [](Uninit) -> LevelState { return LevelState(Invalid{}); }, [](Invalid) -> LevelState { return LevelState(Invalid{}); }, [&](int val) -> LevelState { - return LevelState(val + (int)op.getLevelToDrop()); + return LevelState(val + op.getLevelsToDrop()); }, }, operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("level_reduce", operands, result)); + LLVM_DEBUG(debugLog("ReduceLevelOpInterface", operands, result)); return result; } -LevelState transferForward(mgmt::LevelReduceMinOp op, +LevelState transferForward(ReducesAllLevelsOpInterface op, ArrayRef operands) { LevelState result = std::visit( Overloaded{ @@ -92,11 +78,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op, [](int val) -> LevelState { return LevelState(MaxLevel{}); }, }, operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("level_reduce_min", operands, result)); + LLVM_DEBUG(debugLog("ReduceAllLevelsOpInterface", operands, result)); return result; } -LevelState transferForward(mgmt::BootstrapOp op, +LevelState transferForward(ResetsLevelOpInterface op, ArrayRef operands) { LevelState result = std::visit( Overloaded{ @@ -106,15 +92,18 @@ LevelState transferForward(mgmt::BootstrapOp op, [](int val) -> LevelState { return LevelState(0); }, }, operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("bootstrap", operands, result)); + LLVM_DEBUG(debugLog("ResetsLevelOpInterface", operands, result)); return result; } LevelState deriveResultLevel(Operation* op, ArrayRef operands) { return llvm::TypeSwitch(*op) - .Case( + .Case( + [&](auto op) -> LevelState { return transferForward(op, operands); }) + .Case( + [&](auto op) -> LevelState { return transferForward(op, operands); }) + .Case( [&](auto op) -> LevelState { return transferForward(op, operands); }) .Default([&](auto& op) -> LevelState { LevelState result; diff --git a/lib/Dialect/HEIRInterfaces.td b/lib/Dialect/HEIRInterfaces.td index d217278cf4..a910e5ec4e 100644 --- a/lib/Dialect/HEIRInterfaces.td +++ b/lib/Dialect/HEIRInterfaces.td @@ -34,6 +34,41 @@ def ResetsMulDepthOpInterface : OpInterface<"ResetsMulDepthOpInterface"> { }]; } +def ResetsLevelOpInterface : OpInterface<"ResetsLevelOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + An interface that signals when an operation resets level + among its results, such as a `mgmt.bootstrap`. + }]; +} + +def ReducesLevelOpInterface : OpInterface<"ReducesLevelOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + An interface that signals when an operation reduces level + among its results, such as a `mgmt.mod_reduce` or `ckks.rescale`. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/"Return the number of levels to reduce by.", + /*retTy=*/"int", + /*methodName=*/"getLevelsToDrop", + /*args=*/(ins ), + /*body=*/[{}], + /*defaultBody=*/[{ return 1; }] + >, + ]; +} + +def ReducesAllLevelsOpInterface : OpInterface<"ReducesAllLevelsOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + An interface that signals when an operation reduces all level + among its results, such as a `mgmt.level_reduce_min`. + }]; +} + def LUTOpInterface : OpInterface<"LUTOpInterface"> { let cppNamespace = "::mlir::heir"; let description = [{ diff --git a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td index 56619e4c14..a24ff1fe87 100644 --- a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td @@ -182,8 +182,8 @@ def Lattigo_BGVMulOp : Lattigo_BGVBinaryInPlaceOp<"mul", [IncreasesMulDepthOpInt }]; } -class Lattigo_BGVUnaryOp : - Lattigo_BGVOp { +class Lattigo_BGVUnaryOp traits = []> : + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$input @@ -198,7 +198,7 @@ def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> { }]; } -def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> { +def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo BGV dialect. @@ -258,7 +258,7 @@ def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInPlaceOp<"relinearize"> { }]; } -def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale"> { +def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo BGV dialect. diff --git a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td index 81cb4a6c50..709b8aa7e9 100644 --- a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td @@ -215,8 +215,8 @@ def Lattigo_CKKSMulOp : Lattigo_CKKSBinaryInPlaceOp<"mul", [IncreasesMulDepthOpI }]; } -class Lattigo_CKKSUnaryOp : - Lattigo_CKKSOp { +class Lattigo_CKKSUnaryOp traits = []> : + Lattigo_CKKSOp { let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, Lattigo_RLWECiphertext:$input @@ -231,7 +231,7 @@ def Lattigo_CKKSRelinearizeNewOp : Lattigo_CKKSUnaryOp<"relinearize_new"> { }]; } -def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new"> { +def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo CKKS dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo CKKS dialect. @@ -284,7 +284,7 @@ def Lattigo_CKKSRelinearizeOp : Lattigo_CKKSUnaryInPlaceOp<"relinearize"> { }]; } -def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale"> { +def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo CKKS dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo CKKS dialect. @@ -322,7 +322,7 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [ let hasVerifier = 1; } -def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> { +def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap", [ResetsLevelOpInterface]> { let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect"; let description = [{ Bootstraps a ciphertext value in the Lattigo CKKS dialect. diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.cpp b/lib/Dialect/Lattigo/IR/LattigoOps.cpp index 10336f9a85..49e55a9c49 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.cpp +++ b/lib/Dialect/Lattigo/IR/LattigoOps.cpp @@ -47,6 +47,10 @@ LogicalResult RLWENewEncryptorOp::verify() { return success(); } +int RLWEDropLevelNewOp::getLevelsToDrop() { return getLevelToDrop(); } + +int RLWEDropLevelOp::getLevelsToDrop() { return getLevelToDrop(); } + LogicalResult BGVRotateColumnsNewOp::verify() { return containsExactlyOneOrEmitError(getOperation(), getDynamicShift(), getStaticShift()); diff --git a/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td b/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td index bfd230df23..e3d9a92d22 100644 --- a/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td @@ -122,7 +122,8 @@ def Lattigo_RLWEDecryptOp : Lattigo_RLWEOp<"decrypt"> { let results = (outs Lattigo_RLWEPlaintext:$plaintext); } -def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> { +def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new", + [DeclareOpInterfaceMethods]> { let summary = "Drop level of a ciphertext"; let arguments = (ins Lattigo_RLWEEvaluator:$evaluator, @@ -132,7 +133,8 @@ def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> { let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InPlaceOpInterface]> { +def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", + [InPlaceOpInterface, DeclareOpInterfaceMethods]> { let summary = "Drop level of a ciphertext"; let description = [{ This operation drops the level of a ciphertext diff --git a/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp b/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp index e4c8d29ba4..055298806a 100644 --- a/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp +++ b/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp @@ -2,36 +2,56 @@ #include +#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h" +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Dialect/Lattigo/IR/LattigoOps.h" #include "lib/Dialect/Lattigo/IR/LattigoTypes.h" #include "lib/Utils/AllocToInPlaceUtils.h" -#include "mlir/include/mlir/Analysis/Liveness.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/Liveness.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project +#define DEBUG_TYPE "alloc-to-inplace" + namespace mlir { namespace heir { namespace lattigo { +namespace { + +// Sets the level of a potentially newly created value. +static inline void setValueToLevel(DataFlowSolver* solver, Value value, + int level) { + auto* lattice = solver->getOrCreateState(value); + lattice->getValue().setLevel(level); +} + +} // namespace + template struct ConvertBinOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertBinOp(mlir::MLIRContext* context, Liveness* liveness, + DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(BinOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -45,13 +65,15 @@ struct ConvertBinOp : public OpRewritePattern { // Update storage info, which must happen before the op is removed storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); - + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -60,16 +82,17 @@ struct ConvertUnaryOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertUnaryOp( - mlir::MLIRContext* context, Liveness* liveness, + mlir::MLIRContext* context, Liveness* liveness, DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(UnaryOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -82,12 +105,15 @@ struct ConvertUnaryOp : public OpRewritePattern { op.getOperand(0), op.getOperand(1), storage); storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -96,16 +122,17 @@ struct ConvertRotateOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertRotateOp( - mlir::MLIRContext* context, Liveness* liveness, + mlir::MLIRContext* context, Liveness* liveness, DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(RotateOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -132,12 +159,15 @@ struct ConvertRotateOp : public OpRewritePattern { // update storage info storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -146,16 +176,17 @@ struct ConvertDropLevelOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertDropLevelOp( - mlir::MLIRContext* context, Liveness* liveness, + mlir::MLIRContext* context, Liveness* liveness, DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(DropLevelOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -169,12 +200,15 @@ struct ConvertDropLevelOp : public OpRewritePattern { // update storage info storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -185,6 +219,14 @@ struct AllocToInPlace : impl::AllocToInPlaceBase { using AllocToInPlaceBase::AllocToInPlaceBase; void runOnOperation() override { + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + } Liveness liveness(getOperation()); MLIRContext* context = &getContext(); @@ -213,8 +255,8 @@ struct AllocToInPlace : impl::AllocToInPlaceBase { // RLWE ConvertUnaryOp, ConvertDropLevelOp>(context, &liveness, - &blockToStorageInfo); + lattigo::RLWEDropLevelOp>>( + context, &liveness, &solver, &blockToStorageInfo); // The greedy policy relies on the order of processing the operations. walkAndApplyPatterns(getOperation(), std::move(patterns)); diff --git a/lib/Dialect/Lattigo/Transforms/BUILD b/lib/Dialect/Lattigo/Transforms/BUILD index 3aa2461da7..09edc912b6 100644 --- a/lib/Dialect/Lattigo/Transforms/BUILD +++ b/lib/Dialect/Lattigo/Transforms/BUILD @@ -23,9 +23,11 @@ cc_library( hdrs = ["AllocToInPlace.h"], deps = [ ":pass_inc_gen", + "@heir//lib/Analysis/LevelAnalysis", + "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect/Lattigo/IR:Dialect", "@heir//lib/Utils:AllocToInPlaceUtils", - "@heir//lib/Utils/Tablegen:InPlaceOpInterface", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/Mgmt/IR/MgmtOps.cpp b/lib/Dialect/Mgmt/IR/MgmtOps.cpp index 3806b30b4e..89d30c4ee5 100644 --- a/lib/Dialect/Mgmt/IR/MgmtOps.cpp +++ b/lib/Dialect/Mgmt/IR/MgmtOps.cpp @@ -1,5 +1,6 @@ #include "lib/Dialect/Mgmt/IR/MgmtOps.h" +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Mgmt/IR/MgmtPatterns.h" #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project @@ -49,6 +50,8 @@ void cleanupInitOp(Operation* top) { }); } +int LevelReduceOp::getLevelsToDrop() { return getLevelToDrop(); } + } // namespace mgmt } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Mgmt/IR/MgmtOps.td b/lib/Dialect/Mgmt/IR/MgmtOps.td index 8b899bbe43..82f9095a63 100644 --- a/lib/Dialect/Mgmt/IR/MgmtOps.td +++ b/lib/Dialect/Mgmt/IR/MgmtOps.td @@ -15,7 +15,7 @@ class Mgmt_Op traits = []> : let cppNamespace = "::mlir::heir::mgmt"; } -def Mgmt_ModReduceOp : Mgmt_Op<"modreduce"> { +def Mgmt_ModReduceOp : Mgmt_Op<"modreduce", [ReducesLevelOpInterface]> { let summary = "Modulus switch the input ciphertext down by one limb (RNS assumed)"; let description = [{ @@ -35,7 +35,8 @@ def Mgmt_ModReduceOp : Mgmt_Op<"modreduce"> { let hasCanonicalizer = 1; } -def Mgmt_LevelReduceOp : Mgmt_Op<"level_reduce"> { +def Mgmt_LevelReduceOp : Mgmt_Op<"level_reduce", + [DeclareOpInterfaceMethods]> { let summary = "Reduce the level of input ciphertext by dropping the last k RNS limbs"; let description = [{ @@ -60,7 +61,7 @@ def Mgmt_LevelReduceOp : Mgmt_Op<"level_reduce"> { let hasCanonicalizer = 1; } -def Mgmt_LevelReduceMinOp : Mgmt_Op<"level_reduce_min"> { +def Mgmt_LevelReduceMinOp : Mgmt_Op<"level_reduce_min", [ReducesAllLevelsOpInterface]> { let summary = "Reduce the level of input ciphertext to the minimum level"; let description = [{ This scheme-agonistic operation reduces the ciphertext level @@ -99,7 +100,7 @@ def Mgmt_RelinearizeOp : Mgmt_Op<"relinearize"> { let assemblyFormat = "operands attr-dict `:` type($output)"; } -def Mgmt_BootstrapOp : Mgmt_Op<"bootstrap", [ResetsMulDepthOpInterface]> { +def Mgmt_BootstrapOp : Mgmt_Op<"bootstrap", [ResetsMulDepthOpInterface, ResetsLevelOpInterface]> { let summary = "Bootstrap the input ciphertext to refresh its noise budget"; let description = [{ diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index db01679559..bc75c439ee 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -505,8 +505,7 @@ BackendPipelineBuilder toLattigoPipelineBuilder() { pm.addPass(lwe::createLWEToLattigo()); // Convert Alloc Ops to InPlace Ops - // TODO(#2635): Disable until this is fixed. - // pm.addPass(lattigo::createAllocToInPlace()); + pm.addPass(lattigo::createAllocToInPlace()); // Simplify, in case the lowering revealed redundancy pm.addPass(createCanonicalizerPass()); diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index 784a5cfb13..0122cb9c40 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -1942,6 +1942,10 @@ LogicalResult LattigoEmitter::printOperation(CKKSRescaleOp op) { } LogicalResult LattigoEmitter::printOperation(CKKSRotateOp op) { + auto inputName = getName(op.getInput()); + auto inplaceName = getName(op.getInplace()); + os << inplaceName << ".Resize(" << inputName << ".Degree()," << inputName + << ".Level())\n"; auto errName = getErrName(); os << errName << " := " << getName(op.getEvaluator()) << ".Rotate("; os << getName(op.getInput()) << ", "; diff --git a/lib/Utils/AllocToInPlaceUtils.h b/lib/Utils/AllocToInPlaceUtils.h index 026687cf6a..cdf293a0ad 100644 --- a/lib/Utils/AllocToInPlaceUtils.h +++ b/lib/Utils/AllocToInPlaceUtils.h @@ -1,3 +1,7 @@ +#include + +#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h" +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #ifndef LIB_UTILS_ALLOCTOINPLACEUTILS_H_ #define LIB_UTILS_ALLOCTOINPLACEUTILS_H_ @@ -21,6 +25,18 @@ namespace mlir { namespace heir { +static std::optional getLevel(Value value, DataFlowSolver* solver) { + auto* lattice = solver->lookupState(value); + if (!lattice || !lattice->getValue().isInitialized()) { + return std::nullopt; + } + auto latticeVal = lattice->getValue(); + if (!latticeVal.isInt()) { + return std::nullopt; + } + return lattice->getValue().getInt(); +} + // CallerProvidedStorageInfo provides an analysis of SSA values that // can be reused for in-place operations that require the caller to pass // in pre-allocated memory for the operation to use. @@ -104,11 +120,22 @@ class CallerProvidedStorageInfo { // various accelerators. One basic optimization is to use the dead value that // is closest to the current operation in the block. But as we do not have the // information of the memory layout, we do not implement this optimization. - Value getAvailableStorage(Operation* op, Liveness* liveness) const { + Value getAvailableStorage(Operation* op, Liveness* liveness, + DataFlowSolver* solver) const { LLVM_DEBUG(llvm::dbgs() << "getAvailableStorage for op " << op->getName() << "\n"); for (auto& [storage, values] : storageToReferringValues) { // storage and all referring values are dead + if (solver) { + auto opLevel = getLevel(op->getResult(0), solver); + auto storageLevel = getLevel(storage, solver); + if (!opLevel.has_value() || !storageLevel.has_value()) { + continue; + } + if (opLevel.value() != storageLevel.value()) { + continue; + } + } if (std::all_of( values.begin(), values.end(), [&](Value value) { return liveness->isDeadAfter(value, op); }) && diff --git a/lib/Utils/BUILD b/lib/Utils/BUILD index 36896dc72d..c81ba077d1 100644 --- a/lib/Utils/BUILD +++ b/lib/Utils/BUILD @@ -232,7 +232,9 @@ cc_library( srcs = ["AllocToInPlaceUtils.cpp"], hdrs = ["AllocToInPlaceUtils.h"], deps = [ + "@heir//lib/Analysis/LevelAnalysis", "@heir//lib/Utils/Tablegen:InPlaceOpInterface", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir new file mode 100644 index 0000000000..6fe69c3ece --- /dev/null +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir @@ -0,0 +1,23 @@ +// RUN: heir-opt --lattigo-alloc-to-inplace %s | FileCheck %s + +// Use the minimum level level of the two operands for the result storage + +!evaluator = !lattigo.bgv.evaluator +!ct = !lattigo.rlwe.ciphertext + +// CHECK: ![[evaluator:.*]] = !lattigo.bgv.evaluator + +// CHECK: func.func @drop_level +// CHECK-SAME: %[[evaluator:.*]]: ![[evaluator]] +func.func @drop_level(%evaluator : !evaluator, %ct : !ct) -> !ct { + %ct_level_0 = lattigo.bgv.rotate_columns_new %evaluator, %ct {static_shift = 4} : (!evaluator, !ct) -> !ct + // CHECK: %[[ct_level_2:.*]] = lattigo.rlwe.drop_level_new + // CHECK-SAME: levelToDrop = 2 + // CHECK: %[[ct_level_4:.*]] = lattigo.rlwe.drop_level_new + // CHECK-SAME: levelToDrop = 4 + %0 = lattigo.rlwe.drop_level_new %evaluator, %ct { levelToDrop = 2 } : (!evaluator, !ct) -> !ct + %1 = lattigo.rlwe.drop_level_new %evaluator, %ct_level_0 { levelToDrop = 4 } : (!evaluator, !ct) -> !ct + // CHECK: lattigo.bgv.add %[[evaluator]], %[[ct_level_2]], %[[ct_level_4]], %[[ct_level_4]] + %2 = lattigo.bgv.add_new %evaluator, %0, %1 : (!evaluator, !ct, !ct) -> !ct + return %2 : !ct +} diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir index 481f73df66..de8729b2d2 100644 --- a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir @@ -3,7 +3,15 @@ // CHECK: func.func @dot_product func.func @dot_product(%evaluator: !lattigo.bgv.evaluator, %param: !lattigo.bgv.parameter, %encoder: !lattigo.bgv.encoder, %ct: !lattigo.rlwe.ciphertext, %ct_0: !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { // no new allocation found as the two ciphertexts in function argument are enough to store the imtermediate results - // CHECK-NOT: _new + // a new allocation is only needed for the rescale because of level change + // CHECK-NOT: mul_new + // CHECK-NOT: relinearize_new + // CHECK-NOT: rotate_columns_new + // CHECK-NOT: add_new + // CHECK: rescale_new + // CHECK-NOT: mul_new + // CHECK-NOT: rotate_columns_new + // CHECK: return %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir index db39d939d5..290bd7f103 100644 --- a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir @@ -12,7 +12,15 @@ module attributes {bgv.schemeParam = #bgv.scheme_param !ct attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { // no new allocation found as the two ciphertexts in function argument are enough to store the imtermediate results - // CHECK-NOT: _new + // a new allocation is only needed for the rescale because of level change + // CHECK-NOT: mul_new + // CHECK-NOT: relinearize_new + // CHECK-NOT: rotate_columns_new + // CHECK-NOT: add_new + // CHECK: rescale_new + // CHECK-NOT: mul_new + // CHECK-NOT: rotate_columns_new + // CHECK: return %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index @@ -37,7 +45,14 @@ module attributes {bgv.schemeParam = #bgv.scheme_param !ct attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { // no new allocation found as the two ciphertexts in function argument are enough to store the imtermediate results - // CHECK-NOT: _new + // CHECK-NOT: mul_new + // CHECK-NOT: relinearize_new + // CHECK-NOT: rotate_columns_new + // CHECK-NOT: add_new + // CHECK: rescale_new + // CHECK-NOT: mul_new + // CHECK-NOT: rotate_columns_new + // CHECK: return %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index diff --git a/tests/Dialect/Mgmt/Transforms/level_reduce.mlir b/tests/Dialect/Mgmt/Transforms/level_reduce.mlir new file mode 100644 index 0000000000..31f7054ee5 --- /dev/null +++ b/tests/Dialect/Mgmt/Transforms/level_reduce.mlir @@ -0,0 +1,13 @@ +// RUN: heir-opt --annotate-mgmt %s | FileCheck %s + +func.func @main(%arg0: !secret.secret>) -> !secret.secret> { + // CHECK: secret.generic + // CHECK-SAME: level = 2 + %b = secret.generic(%arg0: !secret.secret>) { + ^body(%clear_a: tensor<8xi8>): + %c = mgmt.level_reduce %clear_a { levelToDrop = 2 }: tensor<8xi8> + secret.yield %c : tensor<8xi8> + // CHECK: } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> !secret.secret> + func.return %b : !secret.secret> +} diff --git a/tests/Examples/lattigo/ckks/in_place/BUILD b/tests/Examples/lattigo/ckks/in_place/BUILD new file mode 100644 index 0000000000..9e93ad639c --- /dev/null +++ b/tests/Examples/lattigo/ckks/in_place/BUILD @@ -0,0 +1,30 @@ +# See README.md for setup required to run these tests + +load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib") +load("@rules_go//go:def.bzl", "go_test") + +package(default_applicable_licenses = ["@heir//:license"]) + +heir_lattigo_lib( + name = "in_place", + go_library_name = "main", + heir_opt_flags = [ + "--canonicalize", + "--cse", + "--scheme-to-lattigo", + ], + mlir_src = "in_place.mlir", +) + +# For Google-internal reasons we must separate the go_test rules from the macro +# above. + +go_test( + name = "in_place_test", + srcs = ["in_place_test.go"], + embed = [":main"], + deps = [ + "@com_github_tuneinsight_lattigo_v6//core/rlwe", + "@com_github_tuneinsight_lattigo_v6//schemes/ckks", + ], +) diff --git a/tests/Examples/lattigo/ckks/in_place/in_place.mlir b/tests/Examples/lattigo/ckks/in_place/in_place.mlir new file mode 100644 index 0000000000..892595f837 --- /dev/null +++ b/tests/Examples/lattigo/ckks/in_place/in_place.mlir @@ -0,0 +1,46 @@ +!Z1056763241666817029_i64 = !mod_arith.int<1056763241666817029 : i64> +!Z1106058412451299513_i64 = !mod_arith.int<1106058412451299513 : i64> +!Z957769724367225479_i64 = !mod_arith.int<957769724367225479 : i64> +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding1 = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding2 = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +#modulus_chain_L1_C1 = #lwe.modulus_chain, current = 1> +#modulus_chain_L2_C2 = #lwe.modulus_chain, current = 2> +#ring_f64_1_x131072 = #polynomial.ring> +!rns_L1 = !rns.rns +!rns_L2 = !rns.rns +!pt = !lwe.lwe_plaintext> +#ring_rns_L1_1_x131072 = #polynomial.ring> +#ring_rns_L2_1_x131072 = #polynomial.ring> +#ciphertext_space_L1 = #lwe.ciphertext_space +#ciphertext_space_L2 = #lwe.ciphertext_space +!ct_L1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> +!ct_L2 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L2_C2> +!ct_L2_1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L2_C2> +module attributes {ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { + func.func @in_place(%ct: !ct_L2) -> !ct_L1 { + %cst = arith.constant dense<0.000000e+00> : tensor<65536xf64> + %ct_0 = ckks.rotate %ct {static_shift = 0 : i32} : !ct_L2 + %pt = lwe.rlwe_encode %cst {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt + %ct_1 = ckks.mul_plain %ct_0, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_2 = ckks.rescale %ct_1 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_3 = ckks.rotate %ct {static_shift = 1 : i32} : !ct_L2 + %ct_4 = ckks.mul_plain %ct_3, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_5 = ckks.rescale %ct_4 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_6 = ckks.add %ct_2, %ct_5 : (!ct_L1, !ct_L1) -> !ct_L1 + %ct_7 = ckks.rotate %ct {static_shift = 2 : i32} : !ct_L2 + %ct_8 = ckks.mul_plain %ct_7, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_9 = ckks.rescale %ct_8 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_10 = ckks.add %ct_6, %ct_9 : (!ct_L1, !ct_L1) -> !ct_L1 + %ct_11 = ckks.rotate %ct {static_shift = 3 : i32} : !ct_L2 + %ct_12 = ckks.mul_plain %ct_11, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_13 = ckks.rescale %ct_12 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_14 = ckks.add %ct_10, %ct_13 : (!ct_L1, !ct_L1) -> !ct_L1 + %ct_15 = ckks.rotate %ct {static_shift = 4 : i32} : !ct_L2 + %ct_16 = ckks.mul_plain %ct_15, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_17 = ckks.rescale %ct_16 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_18 = ckks.add %ct_14, %ct_17 : (!ct_L1, !ct_L1) -> !ct_L1 + return %ct_18 : !ct_L1 + } +} diff --git a/tests/Examples/lattigo/ckks/in_place/in_place_test.go b/tests/Examples/lattigo/ckks/in_place/in_place_test.go new file mode 100644 index 0000000000..a7445e16aa --- /dev/null +++ b/tests/Examples/lattigo/ckks/in_place/in_place_test.go @@ -0,0 +1,105 @@ +package main + +import ( + "fmt" + "testing" + "time" + + "github.com/tuneinsight/lattigo/v6/core/rlwe" + "github.com/tuneinsight/lattigo/v6/schemes/ckks" +) + +// MakeFlattenedOnes creates a slice of float64 filled with 1.0s. +// The size of the slice is determined by the product of the input 2D dimensions (rows * cols). +func MakeFlattenedOnes(rows, cols int) []float64 { + size := rows * cols + tensor := make([]float64, size) + for i := range tensor { + tensor[i] = 1.0 + } + return tensor +} + +func makeRange(n int) []int { + a := make([]int, n) + for i := range a { + a[i] = i + } + return a +} + +func generateGalEls(param ckks.Parameters, indices []int) []uint64 { + var galEls []uint64 + for _, index := range indices { + galEls = append(galEls, param.GaloisElement(index)) + } + return galEls +} + +func TestMLP(t *testing.T) { + logN := 14 + numSlots := 1 << (logN - 1) + + // Input is arbitrary, doesn't matter since we're just testing + // performance + inputClear := make([]float64, numSlots) + for i := range inputClear { + inputClear[i] = 1.0 + } + + // Function args: + // + // %ct: encrypted input, + + // These parameters should match the mlir file, though due to the weird + // nature of this test, this is the source of truth for what is used, + // not the mlir file. + logQ := make([]int, 7) + for i := range logQ { + logQ[i] = 60 + } + param, err := ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + LogN: logN, + LogQ: logQ, + LogP: []int{60}, + LogDefaultScale: 40, + }) + if err != nil { + panic(err) + } + + encoder := ckks.NewEncoder(param) + kgen := rlwe.NewKeyGenerator(param) + sk, pk := kgen.GenKeyPairNew() + encryptor := rlwe.NewEncryptor(param, pk) + rk := kgen.GenRelinearizationKeyNew(sk) + + // We have to do this once for each distinct linear_transform op to + // ensure we generate all the galois keys needed by lattigo + var galEls []uint64 + // Manually add Galois key for extra rotation indices used in the + // mlir file, outside of linear_transform + // + // For some reason I need to manually add rotation keys used in + // linear_transform! That should have been handled by the above code... + rotIndices := makeRange(10) + galEls = append(galEls, generateGalEls(param, rotIndices)...) + + fmt.Printf("Final galEls: %v\n", galEls) + + evk := rlwe.NewMemEvaluationKeySet(rk, kgen.GenGaloisKeysNew(galEls, sk)...) + evaluator := ckks.NewEvaluator(param, evk) + + pt := ckks.NewPlaintext(param, 2) + encoder.Encode(inputClear, pt) + ctInput, err25 := encryptor.EncryptNew(pt) + if err25 != nil { + panic(err25) + } + + fmt.Printf("Starting call") + startTime := time.Now() + in_place(evaluator, param, encoder, ctInput) + duration := time.Since(startTime) + fmt.Printf("MLP call took: %v\n", duration) +}