diff --git a/lib/Target/Lattigo/BUILD b/lib/Target/Lattigo/BUILD index 89cb472230..26980874f4 100644 --- a/lib/Target/Lattigo/BUILD +++ b/lib/Target/Lattigo/BUILD @@ -36,3 +36,51 @@ cc_library( "@llvm-project//mlir:TranslateLib", ], ) + +cc_library( + name = "LattigoDebugEmitter", + srcs = ["LattigoDebugEmitter.cpp"], + hdrs = [ + "LattigoDebugEmitter.h", + "LattigoTemplates.h", + ], + deps = [ + "@heir//lib/Dialect:ModuleAttributes", + "@heir//lib/Dialect/Lattigo/IR:Dialect", + "@heir//lib/Utils:TargetUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + ], +) + +cc_library( + name = "LattigoRegistration", + srcs = [ + "LattigoTranslateRegistration.cpp", + ], + hdrs = [ + "LattigoTranslateRegistration.h", + ], + deps = [ + ":LattigoDebugEmitter", + ":LattigoEmitter", + "@heir//lib/Dialect:ModuleAttributes", + "@heir//lib/Dialect/Lattigo/IR:Dialect", + "@heir//lib/Dialect/Mgmt/IR:Dialect", + "@heir//lib/Dialect/RNS/IR:Dialect", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + ], +) diff --git a/lib/Target/Lattigo/LattigoDebugEmitter.cpp b/lib/Target/Lattigo/LattigoDebugEmitter.cpp new file mode 100644 index 0000000000..53ff822dff --- /dev/null +++ b/lib/Target/Lattigo/LattigoDebugEmitter.cpp @@ -0,0 +1,247 @@ +#include "lib/Target/Lattigo/LattigoDebugEmitter.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "lib/Dialect/Lattigo/IR/LattigoDialect.h" +#include "lib/Dialect/Lattigo/IR/LattigoOps.h" +#include "lib/Dialect/Lattigo/IR/LattigoTypes.h" +#include "lib/Dialect/ModuleAttributes.h" +#include "lib/Target/Lattigo/LattigoTemplates.h" +#include "lib/Utils/TargetUtils.h" +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project +#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/ManagedStatic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project +#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lattigo { + +LogicalResult translateToDebugEmitter(mlir::Operation* op, + llvm::raw_ostream& os, + const std::string& packageName) { + LattigoDebugEmitter emitter(os, packageName); + LogicalResult result = emitter.translate(*op); + return result; +} + +FailureOr LattigoDebugEmitter::convertType(Type type) { + return llvm::TypeSwitch>(type) + // RLWE + .Case( + [&](auto ty) { return std::string("*rlwe.Ciphertext"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.Plaintext"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.PrivateKey"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.PublicKey"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.KeyGenerator"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.RelinearizationKey"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.GaloisKey"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.EvaluationKeySet"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.Encryptor"); }) + .Case( + [&](auto ty) { return std::string("*rlwe.Decryptor"); }) + .Case( + [&](auto ty) { return std::string("*bgv.Encoder"); }) + .Case( + [&](auto ty) { return std::string("*bgv.Evaluator"); }) + .Case( + [&](auto ty) { return std::string("bgv.Parameters"); }) + .Case( + [&](auto ty) { return std::string("*ckks.Encoder"); }) + .Case( + [&](auto ty) { return std::string("*ckks.Evaluator"); }) + .Case( + [&](auto ty) { return std::string("*bootstrapping.EvaluationKeys"); }) + .Case( + [&](auto ty) { return std::string("*bootstrapping.Evaluator"); }) + .Case( + [&](auto ty) { return std::string("ckks.Parameters"); }) + .Case( + [&](auto ty) { return std::string("bootstrapping.Parameters"); }) + .Default([&](Type) -> FailureOr { return failure(); }); +} + +LogicalResult LattigoDebugEmitter::emitDebugHelperSignature( + ::mlir::func::FuncOp funcOp, ErrorEmitterFn emitError) { + auto argTypes = funcOp.getArgumentTypes(); + + if (argTypes.size() != 5) { + return emitError( + funcOp.getLoc(), + llvm::formatv( + "Unexpected debug port signature: expected 5 args, got {0}", + argTypes.size())); + } + + llvm::SmallVector funcArgs; + for (size_t i = 0; i < argTypes.size(); i++) { + auto param = convertType(argTypes[i]); + if (failed(param)) + return emitError( + funcOp.getLoc(), + llvm::formatv("Failed to emit type for arg{0}: {1}", i, argTypes[i])); + + funcArgs.push_back(param.value()); + } + + os << "func"; + os << " " << canonicalizeDebugPort(funcOp.getName()) << "("; + + os << kEvalVar << " " << funcArgs[0] << ", "; + os << kParamVar << " " << funcArgs[1] << ", "; + os << kEncodeVar << " " << funcArgs[2] << ", "; + os << kDecryptVar << " " << funcArgs[3] << ", "; + os << kCiphertxtVar << " " << funcArgs[4] << ", "; + os << kDebugAttrMapParam; + os << " " << "map[string]string"; + os << ")"; + return success(); +} + +LogicalResult LattigoDebugEmitter::emitDebugHelperImpl() { + os << "isBlockArgument" << " := " << kDebugAttrMapParam + << "[\"asm.is_block_arg\"]\n"; + + os << "if isBlockArgument == \"1\" {\n"; + os.indent(); + os << "fmt.Println(\"Input\")\n"; + os.unindent(); + os << "} else {\n"; + os.indent(); + os << "fmt.Println(" << kDebugAttrMapParam << "[\"asm.op_name\"])\n"; + os.unindent(); + os << "}\n\n"; + + os << "messageSize, _ := strconv.Atoi(" << kDebugAttrMapParam + << "[\"message.size\"])\n"; + os << "value := make([]int64, messageSize)\n"; + os << "pt := " << kDecryptVar << ".DecryptNew(" << kCiphertxtVar << ")\n"; + os << kEncodeVar << ".Decode(pt, value)\n"; + os << "fmt.Printf(\" %v\\n\", value)\n"; + return success(); +} + +LogicalResult LattigoDebugEmitter::translate(Operation& op) { + LogicalResult status = + llvm::TypeSwitch(op) + // Builtin ops + .Case([&](auto op) { return printOperation(op); }) + // Func ops + .Case([&](auto op) { return printOperation(op); }) + .Default([&](Operation&) { + return emitError(op.getLoc(), "unable to find printer for op"); + }); + + if (failed(status)) { + return emitError(op.getLoc(), + llvm::formatv("Failed to translate op {0}", op.getName())); + } + return success(); +} + +LogicalResult LattigoDebugEmitter::printOperation(ModuleOp moduleOp) { + prelude = "package " + packageName + "\n"; + imports.insert("\"fmt\""); + imports.insert("\"strconv\""); + + imports.insert(std::string(kRlweImport)); + if (moduleIsBGVOrBFV(moduleOp)) { + imports.insert(std::string(kBgvImport)); + } else if (moduleIsCKKS(moduleOp)) { + imports.insert(std::string(kCkksImport)); + } else { + return moduleOp.emitError("Unknown scheme"); + } + + emitPrelude(); + + for (Operation& op : moduleOp) { + if (auto funcOp = dyn_cast(op)) { + if (failed(translate(op))) { + return failure(); + } + } + } + + return success(); +} + +LogicalResult LattigoDebugEmitter::printOperation(func::FuncOp funcOp) { + if (!isDebugPort(funcOp.getName()) || isEmitted) { + return success(); + } + + auto res = emitDebugHelperSignature( + funcOp, [&](Location loc, const std::string& message) { + return emitError(loc, message); + }); + + if (failed(res)) { + return res; + } + + os << " {\n"; + os.indent(); + res = emitDebugHelperImpl(); + if (failed(res)) { + return res; + } + os.unindent(); + os << "}\n"; + isEmitted = true; + return success(); +} + +LattigoDebugEmitter::LattigoDebugEmitter(raw_ostream& os, + const std::string& packageName) + : os(os), packageName(packageName), isEmitted(false) {} + +} // namespace lattigo +} // namespace heir +} // namespace mlir diff --git a/lib/Target/Lattigo/LattigoDebugEmitter.h b/lib/Target/Lattigo/LattigoDebugEmitter.h new file mode 100644 index 0000000000..dc9f465d39 --- /dev/null +++ b/lib/Target/Lattigo/LattigoDebugEmitter.h @@ -0,0 +1,76 @@ +#ifndef LIB_TARGET_LATTIGO_LATTIGODEBUGEMITTER_H_ +#define LIB_TARGET_LATTIGO_LATTIGODEBUGEMITTER_H_ + +#include +#include +#include +#include + +#include "lib/Dialect/Lattigo/IR/LattigoOps.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lattigo { + +using ErrorEmitterFn = std::function; + +/// Translates the given operation to Lattigo +::mlir::LogicalResult translateToDebugEmitter(::mlir::Operation* op, + llvm::raw_ostream& os, + const std::string& packageName); + +class LattigoDebugEmitter { + public: + LattigoDebugEmitter(raw_ostream& os, const std::string& packageName); + + LogicalResult translate(::mlir::Operation& operation); + + void emitPrelude() { + os << "package " << packageName << "\n"; + os << "import (\n"; + for (const auto& import : imports) { + os << " " << import << "\n"; + } + os << ")\n"; + os << "\n"; + } + + private: + /// Output stream to emit to. + raw_indented_ostream os; + + const std::string& packageName; + std::string prelude; + std::set imports; + + // Functions for printing individual ops + LogicalResult printOperation(::mlir::ModuleOp op); + LogicalResult printOperation(::mlir::func::FuncOp op); + + // Emit the default debug helper function signature + LogicalResult emitDebugHelperSignature(::mlir::func::FuncOp funcOp, + ErrorEmitterFn emitError); + + LogicalResult emitDebugHelperImpl(); + + FailureOr convertType(::mlir::Type type); + bool isEmitted; +}; + +} // namespace lattigo +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_LATTIGO_LATTIGODEBUGEMITTER_H_ diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index 5df6f78fa7..07d9e35af9 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -2234,78 +2234,6 @@ LattigoEmitter::LattigoEmitter(raw_ostream& os, extraImports(extraImports), funcFilter(funcFilter) {} -struct TranslateOptions { - llvm::cl::opt packageName{ - "package-name", - llvm::cl::desc("The name to use for the package declaration in the " - "generated golang file.")}; - llvm::cl::list extraImports{ - "extra-imports", llvm::cl::desc("Additional import paths")}; -}; -static llvm::ManagedStatic translateOptions; - -void registerTranslateOptions() { - // Forces initialization of options. - *translateOptions; -} - -void registerToLattigoTranslation() { - TranslateFromMLIRRegistration reg( - "emit-lattigo", - "translate the lattigo dialect to GO code against the Lattigo API", - [](Operation* op, llvm::raw_ostream& output) { - return translateToLattigo(op, output, translateOptions->packageName, - translateOptions->extraImports); - }, - [](DialectRegistry& registry) { - registry - .insert(); - }); -} - -void registerToLattigoPreprocessingTranslation() { - TranslateFromMLIRRegistration reg( - "emit-lattigo-preprocessing", - "translate the lattigo dialect to GO code against the Lattigo API", - [](Operation* op, llvm::raw_ostream& output) { - return translateToLattigo( - op, output, translateOptions->packageName, - translateOptions->extraImports, [](func::FuncOp funcOp) { - return funcOp->hasAttr(kClientPackFuncAttrName); - }); - }, - [](DialectRegistry& registry) { - registry - .insert(); - }); -} - -void registerToLattigoPreprocessedTranslation() { - TranslateFromMLIRRegistration reg( - "emit-lattigo-preprocessed", - "translate the lattigo dialect to GO code against the Lattigo API", - [](Operation* op, llvm::raw_ostream& output) { - return translateToLattigo( - op, output, translateOptions->packageName, - translateOptions->extraImports, [](func::FuncOp funcOp) { - return !funcOp->hasAttr(kClientPackFuncAttrName); - }); - }, - [](DialectRegistry& registry) { - registry - .insert(); - }); -} - } // namespace lattigo } // namespace heir } // namespace mlir diff --git a/lib/Target/Lattigo/LattigoEmitter.h b/lib/Target/Lattigo/LattigoEmitter.h index c634f54d6d..17132de3ec 100644 --- a/lib/Target/Lattigo/LattigoEmitter.h +++ b/lib/Target/Lattigo/LattigoEmitter.h @@ -31,8 +31,6 @@ namespace mlir { namespace heir { namespace lattigo { -void registerTranslateOptions(); - /// Translates the given operation to Lattigo ::mlir::LogicalResult translateToLattigo( ::mlir::Operation* op, llvm::raw_ostream& os, @@ -283,10 +281,6 @@ class LattigoEmitter { bool shouldDeclare = true); }; -void registerToLattigoTranslation(void); -void registerToLattigoPreprocessingTranslation(void); -void registerToLattigoPreprocessedTranslation(void); - } // namespace lattigo } // namespace heir } // namespace mlir diff --git a/lib/Target/Lattigo/LattigoTemplates.h b/lib/Target/Lattigo/LattigoTemplates.h index 4de1110822..53565a7db7 100644 --- a/lib/Target/Lattigo/LattigoTemplates.h +++ b/lib/Target/Lattigo/LattigoTemplates.h @@ -3,6 +3,8 @@ #include +#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project + namespace mlir { namespace heir { namespace lattigo { @@ -33,6 +35,13 @@ constexpr std::string_view kMathImport = "\"math\""; constexpr std::string_view kSlicesImport = "\"slices\""; constexpr std::string_view kMathBigImport = "\"math/big\""; +inline constexpr llvm::StringLiteral kEvalVar = "evaluator"; +inline constexpr llvm::StringLiteral kEncodeVar = "encoder"; +inline constexpr llvm::StringLiteral kDecryptVar = "decryptor"; +inline constexpr llvm::StringLiteral kCiphertxtVar = "ct"; +inline constexpr llvm::StringLiteral kParamVar = "param"; +inline constexpr llvm::StringLiteral kDebugAttrMapParam = "debugAttrMap"; + } // namespace lattigo } // namespace heir } // namespace mlir diff --git a/lib/Target/Lattigo/LattigoTranslateRegistration.cpp b/lib/Target/Lattigo/LattigoTranslateRegistration.cpp new file mode 100644 index 0000000000..859d33af4a --- /dev/null +++ b/lib/Target/Lattigo/LattigoTranslateRegistration.cpp @@ -0,0 +1,126 @@ +#include "lib/Target/Lattigo/LattigoTranslateRegistration.h" + +#include "lib/Dialect/Lattigo/IR/LattigoDialect.h" +#include "lib/Dialect/Lattigo/IR/LattigoOps.h" +#include "lib/Dialect/Lattigo/IR/LattigoTypes.h" +#include "lib/Dialect/Mgmt/IR/MgmtDialect.h" +#include "lib/Dialect/ModuleAttributes.h" +#include "lib/Dialect/RNS/IR/RNSDialect.h" +#include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "lib/Target/Lattigo/LattigoDebugEmitter.h" +#include "lib/Target/Lattigo/LattigoEmitter.h" +#include "lib/Target/Lattigo/LattigoTemplates.h" +#include "lib/Utils/TargetUtils.h" +#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project +#include "llvm/include/llvm/Support/ManagedStatic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lattigo { + +struct TranslateOptions { + llvm::cl::opt packageName{ + "package-name", + llvm::cl::desc("The name to use for the package declaration in the " + "generated golang file."), + llvm::cl::init("main")}; + llvm::cl::list extraImports{ + "extra-imports", llvm::cl::desc("Additional import paths")}; +}; + +static llvm::ManagedStatic translateOptions; + +void registerTranslateOptions() { + // Forces initialization of options. + *translateOptions; +} + +void registerToLattigoTranslation() { + TranslateFromMLIRRegistration reg( + "emit-lattigo", + "translate the lattigo dialect to GO code against the Lattigo API", + [](Operation* op, llvm::raw_ostream& output) { + return translateToLattigo(op, output, translateOptions->packageName, + translateOptions->extraImports); + }, + [](DialectRegistry& registry) { + registry + .insert(); + }); +} + +void registerToLattigoDebugTranslation() { + TranslateFromMLIRRegistration reg( + "emit-lattigo-debug", + "Emit source code containing default debug helper implementation for " + "lattigo dialect", + [](Operation* op, llvm::raw_ostream& output) { + return translateToDebugEmitter(op, output, + translateOptions->packageName); + }, + [](DialectRegistry& registry) { + registry + .insert(); + }); +} + +void registerToLattigoPreprocessingTranslation() { + TranslateFromMLIRRegistration reg( + "emit-lattigo-preprocessing", + "translate the lattigo dialect to GO code against the Lattigo API", + [](Operation* op, llvm::raw_ostream& output) { + return translateToLattigo( + op, output, translateOptions->packageName, + translateOptions->extraImports, [](func::FuncOp funcOp) { + return funcOp->hasAttr(kClientPackFuncAttrName); + }); + }, + [](DialectRegistry& registry) { + registry + .insert(); + }); +} + +void registerToLattigoPreprocessedTranslation() { + TranslateFromMLIRRegistration reg( + "emit-lattigo-preprocessed", + "translate the lattigo dialect to GO code against the Lattigo API", + [](Operation* op, llvm::raw_ostream& output) { + return translateToLattigo( + op, output, translateOptions->packageName, + translateOptions->extraImports, [](func::FuncOp funcOp) { + return !funcOp->hasAttr(kClientPackFuncAttrName); + }); + }, + [](DialectRegistry& registry) { + registry + .insert(); + }); +} + +} // namespace lattigo +} // namespace heir +} // namespace mlir diff --git a/lib/Target/Lattigo/LattigoTranslateRegistration.h b/lib/Target/Lattigo/LattigoTranslateRegistration.h new file mode 100644 index 0000000000..d1e146f0bb --- /dev/null +++ b/lib/Target/Lattigo/LattigoTranslateRegistration.h @@ -0,0 +1,22 @@ +#ifndef LIB_TARGET_LATTIGO_LATTIGOTRANSLATEREGISTRATION_H_ +#define LIB_TARGET_LATTIGO_LATTIGOTRANSLATEREGISTRATION_H_ + +namespace mlir { +namespace heir { +namespace lattigo { + +void registerTranslateOptions(); + +void registerToLattigoTranslation(); + +void registerToLattigoPreprocessingTranslation(); + +void registerToLattigoPreprocessedTranslation(); + +void registerToLattigoDebugTranslation(); + +} // namespace lattigo +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_LATTIGO_LATTIGOTRANSLATEREGISTRATION_H_ diff --git a/tests/Emitter/Lattigo/emit_debug_helper.mlir b/tests/Emitter/Lattigo/emit_debug_helper.mlir new file mode 100644 index 0000000000..04ec1bcef1 --- /dev/null +++ b/tests/Emitter/Lattigo/emit_debug_helper.mlir @@ -0,0 +1,43 @@ +// RUN: heir-translate %s --emit-lattigo-debug | FileCheck %s + +!ct = !lattigo.rlwe.ciphertext +!encryptor = !lattigo.rlwe.encryptor +!decryptor = !lattigo.rlwe.decryptor +!evaluator = !lattigo.bgv.evaluator +!encoder = !lattigo.bgv.encoder +!params = !lattigo.bgv.parameter + + +// CHECK: package [[package:.*]] +// CHECK: import ( +// CHECK: "fmt" +// CHECK: "github.com/tuneinsight/lattigo/v6/core/rlwe" +// CHECK: "github.com/tuneinsight/lattigo/v6/schemes/bgv" +// CHECK: "strconv" +// CHECK: ) +// CHECK: func __heir_debug( +// CHECK-SAME: [[eval:[^ ]+]] *bgv.Evaluator, +// CHECK-SAME: [[param:[^ ]+]] bgv.Parameters, +// CHECK-SAME: [[encoder:[^ ]+]] *bgv.Encoder, +// CHECK-SAME: [[decryptor:[^ ]+]] *rlwe.Decryptor, +// CHECK-SAME: [[ct:[^ ]+]] *rlwe.Ciphertext, +// CHECK-SAME: [[m:debugAttrMap]] map[string]string) { +// CHECK: isBlockArgument := [[m]]["asm.is_block_arg"] +// CHECK: if isBlockArgument == "1" { +// CHECK: fmt.Println("Input") +// CHECK: } else { +// CHECK: fmt.Println([[m]]["asm.op_name"]) +// CHECK: } +// CHECK: messageSize, _ := strconv.Atoi([[m]]["message.size"]) +// CHECK: value := make([]int64, messageSize) +// CHECK: pt := [[decryptor]].DecryptNew([[ct]]) +// CHECK: [[encoder]].Decode(pt, value) +// CHECK: fmt.Printf(" %v\n", value) +// CHECK: } +module attributes {scheme.bgv} { + func.func private @__heir_debug_0(!lattigo.bgv.evaluator, !lattigo.bgv.parameter, !lattigo.bgv.encoder, !lattigo.rlwe.decryptor, !lattigo.rlwe.ciphertext) + func.func @dot_product(%evaluator: !lattigo.bgv.evaluator, %param: !lattigo.bgv.parameter, %encoder: !lattigo.bgv.encoder, %decryptor: !lattigo.rlwe.decryptor, %ct: !lattigo.rlwe.ciphertext, %ct_0: !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext { + call @__heir_debug_0(%evaluator, %param, %encoder, %decryptor, %ct) {bound = "50", random = 3, complex = {test = 1.2}, secret.secret} : (!lattigo.bgv.evaluator, !lattigo.bgv.parameter, !lattigo.bgv.encoder, !lattigo.rlwe.decryptor, !lattigo.rlwe.ciphertext) -> () + return %ct : !lattigo.rlwe.ciphertext + } +} diff --git a/tools/BUILD b/tools/BUILD index 34e037e5bf..0c5c22c871 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -229,7 +229,7 @@ cc_binary( "@heir//lib/Target/FunctionInfo:FunctionInfoEmitter", "@heir//lib/Target/Jaxite:JaxiteEmitter", "@heir//lib/Target/JaxiteWord:JaxiteWordEmitter", - "@heir//lib/Target/Lattigo:LattigoEmitter", + "@heir//lib/Target/Lattigo:LattigoRegistration", "@heir//lib/Target/Metadata:MetadataEmitter", "@heir//lib/Target/OpenFhePke:OpenFheRegistration", "@heir//lib/Target/SCIFRBool:SCIFRBoolEmitter", diff --git a/tools/heir-translate.cpp b/tools/heir-translate.cpp index c4e73daae8..eec5151cd7 100644 --- a/tools/heir-translate.cpp +++ b/tools/heir-translate.cpp @@ -1,7 +1,7 @@ #include "lib/Target/FunctionInfo/FunctionInfoEmitter.h" #include "lib/Target/Jaxite/JaxiteEmitter.h" #include "lib/Target/JaxiteWord/JaxiteWordEmitter.h" -#include "lib/Target/Lattigo/LattigoEmitter.h" +#include "lib/Target/Lattigo/LattigoTranslateRegistration.h" #include "lib/Target/Metadata/MetadataEmitter.h" #include "lib/Target/OpenFhePke/OpenFheTranslateRegistration.h" // This comment includes internal emitters @@ -41,10 +41,11 @@ int main(int argc, char** argv) { mlir::heir::openfhe::registerToOpenFhePkeDebugTranslation(); // Lattigo + mlir::heir::lattigo::registerTranslateOptions(); mlir::heir::lattigo::registerToLattigoTranslation(); mlir::heir::lattigo::registerToLattigoPreprocessingTranslation(); mlir::heir::lattigo::registerToLattigoPreprocessedTranslation(); - mlir::heir::lattigo::registerTranslateOptions(); + mlir::heir::lattigo::registerToLattigoDebugTranslation(); // SCIFRBool mlir::cornami::scifrbool::registerTranslateOptions();