diff --git a/BUILD.bazel b/BUILD.bazel index 4474da4440..6958f5383b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -70,6 +70,21 @@ config_setting( }, ) +# CHEDDAR GPU FHE backend (opt-in, requires CUDA) +# use by passing `--//:enable_cheddar=1` to `bazel build` +string_flag( + name = "enable_cheddar", + build_setting_default = "0", +) + +config_setting( + name = "config_enable_cheddar", + flag_values = { + ":enable_cheddar": "1", + }, + visibility = ["//visibility:public"], +) + # OpenFHE interpreter string_flag( name = "openfhe_enable_timing", diff --git a/MODULE.bazel b/MODULE.bazel index b15bd4d404..3ddb488832 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -25,6 +25,8 @@ bazel_dep(name = "bazel_skylib_gazelle_plugin", version = "1.8.2", dev_dependenc # implicitly depends upon the target '//:license'. How bizarre. bazel_dep(name = "rules_license", version = "1.0.0") bazel_dep(name = "platforms", version = "1.0.0") +bazel_dep(name = "rules_cuda", version = "0.3.0") +bazel_dep(name = "rules_foreign_cc", version = "0.15.1") bazel_dep(name = "rules_go", version = "0.53.0") bazel_dep(name = "rules_python", version = "1.5.1") bazel_dep(name = "googletest", version = "1.17.0") @@ -64,6 +66,18 @@ use_repo( "llvm_zstd", ) +# CHEDDAR GPU FHE library (opt-in, requires CUDA) +cheddar_extensions = use_extension("//bazel:extensions.bzl", "cheddar_deps") +use_repo(cheddar_extensions, "cheddar") + +# CUDA toolkit (for CHEDDAR and other GPU backends) +cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") +cuda.toolkit( + name = "cuda", + toolkit_path = "", +) +use_repo(cuda, "cuda") + # The subset of LLVM backend targets that should be compiled _LLVM_TARGETS = [ "X86", diff --git a/bazel/cheddar/BUILD b/bazel/cheddar/BUILD new file mode 100644 index 0000000000..60c4274c22 --- /dev/null +++ b/bazel/cheddar/BUILD @@ -0,0 +1,7 @@ +# This build file is necessary to mark this directory as a subpackage for bazel +# to have access to the files. + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) diff --git a/bazel/cheddar/cheddar.BUILD b/bazel/cheddar/cheddar.BUILD new file mode 100644 index 0000000000..4cb81ba90f --- /dev/null +++ b/bazel/cheddar/cheddar.BUILD @@ -0,0 +1,56 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake") + +package( + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "all_srcs", + srcs = glob( + ["**"], + exclude = [ + "bazel-*/**", + "build/**", + "cmake-build*/**", + ], + ), +) + +cmake( + name = "cheddar_cmake", + cache_entries = { + "CMAKE_BUILD_TYPE": "Release", + "BUILD_UNITTEST": "OFF", + "ENABLE_EXTENSION": "ON", + "USE_GMP": "OFF", + # Build for the local GPU architecture to keep build times manageable. + "CMAKE_CUDA_ARCHITECTURES": "native", + # Assumes conventional CUDA toolkit location and configuration. + "CMAKE_CUDA_COMPILER": "/usr/local/cuda/bin/nvcc", + "CUDAToolkit_ROOT": "/usr/local/cuda", + "CMAKE_CUDA_HOST_COMPILER:FILEPATH": "/usr/bin/g++", + "CMAKE_CUDA_FLAGS": "-ccbin=/usr/bin/g++", + # Wipe toolchain-provided linker flags that can inject unsupported host options. + "CMAKE_EXE_LINKER_FLAGS": "", + "CMAKE_MODULE_LINKER_FLAGS": "", + "CMAKE_SHARED_LINKER_FLAGS": "", + }, + generate_crosstool_file = False, + lib_source = ":all_srcs", + out_include_dir = "include", + out_shared_libs = [ + "libcheddar.so", + ], + targets = ["cheddar"], +) + +# Wrap the cmake output with CUDA Thrust headers so downstream consumers +# can resolve cheddar's public #include directives. +cc_library( + name = "cheddar", + deps = [ + ":cheddar_cmake", + "@cuda//:thrust", + ], +) diff --git a/bazel/cheddar/config.bzl b/bazel/cheddar/config.bzl new file mode 100644 index 0000000000..5da4eb0e44 --- /dev/null +++ b/bazel/cheddar/config.bzl @@ -0,0 +1,21 @@ +"""Helper macros for CHEDDAR opt-in build configuration.""" + +def if_cheddar_enabled(if_true, if_false = []): + """Select based on whether CHEDDAR is enabled.""" + return select({ + "@heir//:config_enable_cheddar": if_true, + "//conditions:default": if_false, + }) + +def requires_cheddar(): + """Returns target_compatible_with for CHEDDAR-requiring targets.""" + return select({ + "@heir//:config_enable_cheddar": [], + "//conditions:default": ["@platforms//:incompatible"], + }) + +def cheddar_deps(extra = []): + """Returns CHEDDAR library deps, empty when disabled.""" + return if_cheddar_enabled( + ["@cheddar//:cheddar"] + extra, + ) diff --git a/bazel/extensions.bzl b/bazel/extensions.bzl index fedd9b2e70..4779ec0b52 100644 --- a/bazel/extensions.bzl +++ b/bazel/extensions.bzl @@ -53,3 +53,19 @@ def _llvm_deps_impl(_): llvm_deps = module_extension( implementation = _llvm_deps_impl, ) + +# CHEDDAR GPU FHE library +CHEDDAR_COMMIT = "307b49cbe03e7f8f14bf31485f716c1090c9ec9d" + +def _cheddar_deps_impl(_): + maybe( + new_git_repository, + name = "cheddar", + build_file = "@heir//bazel/cheddar:cheddar.BUILD", + commit = CHEDDAR_COMMIT, + remote = "https://github.com/scale-snu/cheddar-fhe.git", + patches = ["@heir//patches:cheddar.patch"], + patch_args = ["-p1"], + ) + +cheddar_deps = module_extension(implementation = _cheddar_deps_impl) diff --git a/lib/Dialect/Cheddar/IR/BUILD b/lib/Dialect/Cheddar/IR/BUILD new file mode 100644 index 0000000000..47e2dd0482 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/BUILD @@ -0,0 +1,156 @@ +# Cheddar dialect implementation + +load("@heir//lib/Dialect:dialect.bzl", "add_heir_dialect_library") +load("@llvm-project//mlir:tblgen.bzl", "td_library") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = [ + "CheddarDialect.cpp", + ], + hdrs = [ + "CheddarAttributes.h", + "CheddarDialect.h", + "CheddarOps.h", + "CheddarTypes.h", + ], + deps = [ + ":CheddarAttributes", + ":CheddarOps", + ":CheddarTypes", + ":attributes_inc_gen", + ":dialect_inc_gen", + ":ops_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) + +cc_library( + name = "CheddarAttributes", + srcs = [ + "CheddarAttributes.cpp", + ], + hdrs = [ + "CheddarAttributes.h", + "CheddarDialect.h", + ], + deps = [ + ":attributes_inc_gen", + ":dialect_inc_gen", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "CheddarTypes", + srcs = [ + "CheddarTypes.cpp", + ], + hdrs = [ + "CheddarAttributes.h", + "CheddarDialect.h", + "CheddarTypes.h", + ], + deps = [ + ":CheddarAttributes", + ":attributes_inc_gen", + ":dialect_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "CheddarOps", + srcs = [ + "CheddarOps.cpp", + ], + hdrs = [ + "CheddarDialect.h", + "CheddarOps.h", + "CheddarTypes.h", + ], + deps = [ + ":CheddarAttributes", + ":CheddarTypes", + ":dialect_inc_gen", + ":ops_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", + "@heir//lib/Utils", + "@heir//lib/Utils:RotationUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", + ], +) + +td_library( + name = "td_files", + srcs = [ + "CheddarAttributes.td", + "CheddarDialect.td", + "CheddarOps.td", + "CheddarTypes.td", + ], + # include from the heir-root to enable fully-qualified include-paths + includes = ["../../../.."], + deps = [ + "@heir//lib/Dialect:td_files", + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +add_heir_dialect_library( + name = "dialect_inc_gen", + dialect = "Cheddar", + kind = "dialect", + td_file = "CheddarDialect.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "attributes_inc_gen", + dialect = "Cheddar", + kind = "attribute", + td_file = "CheddarAttributes.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "types_inc_gen", + dialect = "Cheddar", + kind = "type", + td_file = "CheddarTypes.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "ops_inc_gen", + dialect = "Cheddar", + kind = "op", + td_file = "CheddarOps.td", + deps = [ + ":td_files", + "@heir//lib/Dialect:td_files", + ], +) diff --git a/lib/Dialect/Cheddar/IR/CheddarAttributes.cpp b/lib/Dialect/Cheddar/IR/CheddarAttributes.cpp new file mode 100644 index 0000000000..854bff16e8 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarAttributes.cpp @@ -0,0 +1,7 @@ +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.h" + +namespace mlir { +namespace heir { +namespace cheddar {} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarAttributes.h b/lib/Dialect/Cheddar/IR/CheddarAttributes.h new file mode 100644 index 0000000000..82903f4d83 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarAttributes.h @@ -0,0 +1,9 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARATTRIBUTES_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARATTRIBUTES_H_ + +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" + +#define GET_ATTRDEF_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARATTRIBUTES_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarAttributes.td b/lib/Dialect/Cheddar/IR/CheddarAttributes.td new file mode 100644 index 0000000000..916bbc6368 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarAttributes.td @@ -0,0 +1,30 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARATTRIBUTES_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARATTRIBUTES_TD_ + +include "CheddarDialect.td" + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" + +class Cheddar_Attribute + : AttrDef { + let mnemonic = attrMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def Cheddar_ParameterFileAttr + : Cheddar_Attribute<"ParameterFile", "parameter_file"> { + let summary = "Path to a CHEDDAR parameter JSON file"; + let description = [{ + This attribute holds the path to a CHEDDAR parameter JSON file. + + CHEDDAR parameters are loaded from pre-built JSON files that specify + ring dimension, prime chains, scale, and other CKKS parameters. + }]; + let parameters = (ins + StringRefParameter<"path to JSON parameter file">:$path, + DefaultValuedParameter<"bool", "false", "use 64-bit word type">:$use64Bit + ); +} + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARATTRIBUTES_TD_ diff --git a/lib/Dialect/Cheddar/IR/CheddarDialect.cpp b/lib/Dialect/Cheddar/IR/CheddarDialect.cpp new file mode 100644 index 0000000000..964a416388 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarDialect.cpp @@ -0,0 +1,48 @@ +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" + +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project + +// NOLINTNEXTLINE(misc-include-cleaner): Required to define CheddarOps + +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.h" +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" + +// Generated definitions +#include "lib/Dialect/Cheddar/IR/CheddarDialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc" + +namespace mlir { +namespace heir { +namespace cheddar { + +void CheddarDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.cpp.inc" + >(); + + addTypes< +#define GET_TYPEDEF_LIST +#include "lib/Dialect/Cheddar/IR/CheddarTypes.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc" + >(); +} + +} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarDialect.h b/lib/Dialect/Cheddar/IR/CheddarDialect.h new file mode 100644 index 0000000000..555fe2ad55 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarDialect.h @@ -0,0 +1,10 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_ + +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project + +// Generated headers (block clang-format from messing up order) +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarDialect.td b/lib/Dialect/Cheddar/IR/CheddarDialect.td new file mode 100644 index 0000000000..32f206c876 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarDialect.td @@ -0,0 +1,25 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_ + +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" + +def Cheddar_Dialect : Dialect { + let name = "cheddar"; + let description = [{ + The `cheddar` dialect is an exit dialect for generating C++ code against the + CHEDDAR GPU FHE library API. + + CHEDDAR is a CKKS-only GPU-accelerated FHE library. It supports both 32-bit + and 64-bit word types, with 32-bit being the primary fast path on GPUs. + + See [the Cheddar GitHub repository](https://github.com/scale-snu/cheddar-fhe) + }]; + + let cppNamespace = "::mlir::heir::cheddar"; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; +} + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARDIALECT_TD_ diff --git a/lib/Dialect/Cheddar/IR/CheddarOps.cpp b/lib/Dialect/Cheddar/IR/CheddarOps.cpp new file mode 100644 index 0000000000..5f39af0c8a --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarOps.cpp @@ -0,0 +1,46 @@ +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" + +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Utils/RotationUtils.h" +#include "lib/Utils/Utils.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace cheddar { + +::llvm::SmallVector<::mlir::OpFoldResult> HRotOp::getRotationIndices() { + if (getStaticShift()) return {getStaticShiftAttr()}; + return {getDynamicShift()}; +} + +LogicalResult HRotOp::verify() { + return containsExactlyOneOrEmitError(getOperation(), getDynamicShift(), + getStaticShift()); +} + +::llvm::SmallVector<::mlir::OpFoldResult> HRotAddOp::getRotationIndices() { + return {getDistanceAttr()}; +} + +::llvm::SmallVector<::mlir::OpFoldResult> +LinearTransformOp::getRotationIndices() { + auto diagonalsType = cast(getDiagonals().getType()); + int64_t slots = diagonalsType.getShape()[1]; + int64_t logBSGS = getLogBabyStepGiantStepRatio().getInt(); + auto rotations = lintransRotationIndices( + getDiagonalIndicesAttr().asArrayRef(), slots, logBSGS); + SmallVector result; + result.reserve(rotations.size()); + auto* mlirCtx = (*this)->getContext(); + for (int64_t rot : rotations) { + result.push_back(IntegerAttr::get(IndexType::get(mlirCtx), rot)); + } + return result; +} + +} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarOps.h b/lib/Dialect/Cheddar/IR/CheddarOps.h new file mode 100644 index 0000000000..6a96186196 --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarOps.h @@ -0,0 +1,26 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_ + +#include + +// IWYU pragma: begin_keep +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Dialect/HEIRInterfaces.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +// IWYU pragma: end_keep + +namespace mlir::heir::cheddar { + +// `GetRotKeyOp` stores a static distance attribute, but dynamic +// `ckks.rotate` lowering still needs a placeholder key op so the emitter can +// trace back to the `UserInterface`. This sentinel distance marks that case. +constexpr int64_t kDynamicRotationKeyDistanceSentinel = -1; + +} // namespace mlir::heir::cheddar + +#define GET_OP_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarOps.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarOps.td b/lib/Dialect/Cheddar/IR/CheddarOps.td new file mode 100644 index 0000000000..08f55d107c --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarOps.td @@ -0,0 +1,497 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_TD_ + +include "CheddarDialect.td" +include "CheddarTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "lib/Dialect/HEIRInterfaces.td" + +class Cheddar_Op traits = []> : + Op { + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Setup operations +//===----------------------------------------------------------------------===// + +def Cheddar_CreateContextOp : Cheddar_Op<"create_context"> { + let summary = "Create a CHEDDAR context from parameters"; + let description = [{ + Creates a CHEDDAR Context from a Parameter object. + The context is the main server-side computation engine. + }]; + let arguments = (ins Cheddar_Parameter:$params); + let results = (outs Cheddar_Context:$ctx); +} + +def Cheddar_CreateUserInterfaceOp : Cheddar_Op<"create_user_interface"> { + let summary = "Create a CHEDDAR UserInterface for key gen and encrypt/decrypt"; + let description = [{ + Creates a UserInterface from a context. The UserInterface handles + key generation, encryption, and decryption. Note: for test purposes only. + }]; + let arguments = (ins Cheddar_Context:$ctx); + let results = (outs Cheddar_UserInterface:$ui); +} + +def Cheddar_GetEncoderOp : Cheddar_Op<"get_encoder"> { + let summary = "Get the encoder from a CHEDDAR context"; + let description = [{ + Returns a reference to the context's encoder (context->encoder_). + }]; + let arguments = (ins Cheddar_Context:$ctx); + let results = (outs Cheddar_Encoder:$encoder); +} + +def Cheddar_GetEvkMapOp : Cheddar_Op<"get_evk_map"> { + let summary = "Get the evaluation key map from a UserInterface"; + let description = [{ + Returns the EvkMap from the UserInterface (ui.GetEvkMap()). + }]; + let arguments = (ins Cheddar_UserInterface:$ui); + let results = (outs Cheddar_EvkMap:$evkMap); +} + +def Cheddar_GetMultKeyOp : Cheddar_Op<"get_mult_key"> { + let summary = "Get the multiplication evaluation key"; + let description = [{ + Returns the multiplication key from the UserInterface + (ui.GetMultiplicationKey()). + }]; + let arguments = (ins Cheddar_UserInterface:$ui); + let results = (outs Cheddar_EvalKey:$key); +} + +def Cheddar_GetRotKeyOp : Cheddar_Op<"get_rot_key"> { + let summary = "Get a rotation evaluation key"; + let description = [{ + Returns a rotation key for the given distance from the UserInterface + (ui.GetRotationKey(dist)). + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Builtin_IntegerAttr:$distance + ); + let results = (outs Cheddar_EvalKey:$key); +} + +def Cheddar_GetConjKeyOp : Cheddar_Op<"get_conj_key"> { + let summary = "Get the conjugation evaluation key"; + let arguments = (ins Cheddar_UserInterface:$ui); + let results = (outs Cheddar_EvalKey:$key); +} + +def Cheddar_PrepareRotKeyOp : Cheddar_Op<"prepare_rot_key"> { + let summary = "Generate a rotation key for a given distance"; + let description = [{ + Calls ui.PrepareRotationKey(distance, max_level) to generate a rotation key. + Must be called before using rotation with that distance. + The max_level parameter specifies the maximum ciphertext level at which + this rotation key will be used. + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Builtin_IntegerAttr:$distance, + Builtin_IntegerAttr:$maxLevel + ); + let results = (outs); +} + +//===----------------------------------------------------------------------===// +// Encode / Encrypt / Decrypt operations +//===----------------------------------------------------------------------===// + +def Cheddar_EncodeOp : Cheddar_Op<"encode"> { + let summary = "Encode a message vector into a CHEDDAR plaintext"; + let description = [{ + Calls encoder.Encode(pt, level, scale, message). The message is a vector + of complex numbers (or reals). + }]; + let arguments = (ins + Cheddar_Encoder:$encoder, + RankedTensorOf<[AnyFloat, AnyComplex]>:$message, + Builtin_IntegerAttr:$level, + Builtin_IntegerAttr:$scale + ); + let results = (outs Cheddar_Plaintext:$plaintext); +} + +def Cheddar_EncodeConstantOp : Cheddar_Op<"encode_constant"> { + let summary = "Encode a scalar double into a CHEDDAR constant"; + let description = [{ + Calls encoder.EncodeConstant(constant, level, scale, number). + The result is in RNS form for efficient ciphertext-scalar ops. + }]; + let arguments = (ins + Cheddar_Encoder:$encoder, + AnyFloat:$value, + Builtin_IntegerAttr:$level, + Builtin_IntegerAttr:$scale + ); + let results = (outs Cheddar_Constant:$constant); +} + +def Cheddar_DecodeOp : Cheddar_Op<"decode"> { + let summary = "Decode a CHEDDAR plaintext back to a message vector"; + let description = [{ + Calls encoder.Decode(message, pt). Returns a vector of complex numbers. + }]; + let arguments = (ins + Cheddar_Encoder:$encoder, + Cheddar_Plaintext:$plaintext + ); + let results = (outs RankedTensorOf<[AnyFloat, AnyComplex]>:$message); +} + +def Cheddar_EncryptOp : Cheddar_Op<"encrypt"> { + let summary = "Encrypt a plaintext into a ciphertext"; + let description = [{ + Calls ui.Encrypt(ct, pt). Test-only operation. + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Cheddar_Plaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$ciphertext); +} + +def Cheddar_DecryptOp : Cheddar_Op<"decrypt"> { + let summary = "Decrypt a ciphertext into a plaintext"; + let description = [{ + Calls ui.Decrypt(pt, ct). Test-only operation. + }]; + let arguments = (ins + Cheddar_UserInterface:$ui, + Cheddar_Ciphertext:$ciphertext + ); + let results = (outs Cheddar_Plaintext:$plaintext); +} + +//===----------------------------------------------------------------------===// +// Ciphertext-ciphertext arithmetic operations +//===----------------------------------------------------------------------===// + +class Cheddar_BinaryCtCtOp traits = []> + : Cheddar_Op { + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$lhs, + Cheddar_Ciphertext:$rhs + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_AddOp : Cheddar_BinaryCtCtOp<"add"> { + let summary = "Add two ciphertexts"; + let description = [{ + Calls context->Add(res, a, b). + }]; +} + +def Cheddar_SubOp : Cheddar_BinaryCtCtOp<"sub"> { + let summary = "Subtract two ciphertexts"; + let description = [{ + Calls context->Sub(res, a, b). + }]; +} + +def Cheddar_MultOp : Cheddar_BinaryCtCtOp<"mult", [IncreasesMulDepthOpInterface]> { + let summary = "Multiply two ciphertexts (tensor product, no relin/rescale)"; + let description = [{ + Calls context->Mult(res, a, b). Produces a degree-3 ciphertext. + Does NOT include relinearization or rescaling. + }]; +} + +//===----------------------------------------------------------------------===// +// Ciphertext-plaintext / ciphertext-constant operations +//===----------------------------------------------------------------------===// + +def Cheddar_AddPlainOp : Cheddar_Op<"add_plain"> { + let summary = "Add a plaintext to a ciphertext"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_CiphertextOrPlaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_SubPlainOp : Cheddar_Op<"sub_plain"> { + let summary = "Subtract a plaintext from a ciphertext"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_CiphertextOrPlaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_MultPlainOp : Cheddar_Op<"mult_plain"> { + let summary = "Multiply a ciphertext by a plaintext (no rescale)"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_CiphertextOrPlaintext:$plaintext + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_AddConstOp : Cheddar_Op<"add_const"> { + let summary = "Add a constant to a ciphertext"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Constant:$constant + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_MultConstOp : Cheddar_Op<"mult_const"> { + let summary = "Multiply a ciphertext by a constant (no rescale)"; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$ciphertext, + Cheddar_Constant:$constant + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Unary ciphertext operations +//===----------------------------------------------------------------------===// + +def Cheddar_NegOp : Cheddar_Op<"neg"> { + let summary = "Negate a ciphertext"; + let description = [{ + Calls context->Neg(res, a). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_RescaleOp : Cheddar_Op<"rescale"> { + let summary = "Rescale a ciphertext (drop one level)"; + let description = [{ + Calls context->Rescale(res, a). Reduces the ciphertext level by 1. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_LevelDownOp : Cheddar_Op<"level_down"> { + let summary = "Reduce ciphertext to a target level"; + let description = [{ + Calls context->LevelDown(res, a, target_level). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Builtin_IntegerAttr:$targetLevel + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Key-switching operations +//===----------------------------------------------------------------------===// + +def Cheddar_RelinearizeOp : Cheddar_Op<"relinearize"> { + let summary = "Relinearize a ciphertext (without rescale)"; + let description = [{ + Calls context->Relinearize(res, a, mult_key). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvalKey:$multKey + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_RelinearizeRescaleOp : Cheddar_Op<"relinearize_rescale"> { + let summary = "Fused relinearize + rescale"; + let description = [{ + Calls context->RelinearizeRescale(res, a, mult_key). Faster than + separate Relinearize + Rescale. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvalKey:$multKey + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Compound (fused) operations -- high-performance GPU kernels +//===----------------------------------------------------------------------===// + +def Cheddar_HMultOp : Cheddar_Op<"hmult", [IncreasesMulDepthOpInterface]> { + let summary = "Fused multiply + relinearize (+ optional rescale)"; + let description = [{ + Calls context->HMult(res, a, b, mult_key, rescale). + Single fused GPU kernel launch. The `rescale` attribute controls whether + rescaling is included (default: true). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$lhs, + Cheddar_Ciphertext:$rhs, + Cheddar_EvalKey:$multKey, + DefaultValuedAttr:$rescale + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_HRotOp : Cheddar_Op<"hrot", [ + DeclareOpInterfaceMethods +]> { + let summary = "Fused key-switch + rotation"; + let description = [{ + Calls context->HRot(res, a, rot_key, distance). + Single fused GPU kernel for rotation. + Supports both static and dynamic shift distances. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvalKey:$rotKey, + Optional:$dynamic_shift, + OptionalAttr:$static_shift + ); + let results = (outs Cheddar_Ciphertext:$output); + let hasVerifier = 1; +} + +def Cheddar_HRotAddOp : Cheddar_Op<"hrot_add", [ + DeclareOpInterfaceMethods +]> { + let summary = "Fused rotation + addition"; + let description = [{ + Computes res = rotate(a, distance) + b in a single fused GPU kernel. + Calls context->HRotAdd(res, a, b, rot_key, distance). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_Ciphertext:$addend, + Cheddar_EvalKey:$rotKey, + Builtin_IntegerAttr:$distance + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_HConjOp : Cheddar_Op<"hconj"> { + let summary = "Fused key-switch + conjugation"; + let description = [{ + Calls context->HConj(res, a, conj_key). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvalKey:$conjKey + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_HConjAddOp : Cheddar_Op<"hconj_add"> { + let summary = "Fused conjugation + addition"; + let description = [{ + Computes res = conj(a) + b in a single fused GPU kernel. + Calls context->HConjAdd(res, a, b, conj_key). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_Ciphertext:$addend, + Cheddar_EvalKey:$conjKey + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_MadUnsafeOp : Cheddar_Op<"mad_unsafe"> { + let summary = "Fused multiply-accumulate with constant (no rescale)"; + let description = [{ + Computes res += a * constant (in-place accumulation). + Calls context->MadUnsafe(res, a, constant). + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$accumulator, + Cheddar_Ciphertext:$input, + Cheddar_Constant:$constant + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +//===----------------------------------------------------------------------===// +// Extension operations (bootstrapping, linear transforms, poly eval) +//===----------------------------------------------------------------------===// + +def Cheddar_BootOp : Cheddar_Op<"boot", [ResetsMulDepthOpInterface]> { + let summary = "Bootstrap a ciphertext"; + let description = [{ + Calls boot_ctx->Boot(res, input, evk_map). + Refreshes the ciphertext noise budget. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvkMap:$evkMap + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_LinearTransformOp : Cheddar_Op<"linear_transform", [ + DeclareOpInterfaceMethods +]> { + let summary = "Apply a linear transform on a ciphertext"; + let description = [{ + Applies a matrix-vector product using CHEDDAR's LinearTransform extension + with BSGS optimization and hoisting. + + The `diagonals` input is a 2D tensor where each row is a non-zero diagonal. + The `diagonal_indices` attribute specifies which diagonal each row represents. + The `level` attribute specifies the modulus level for the operation. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvkMap:$evkMap, + 2DTensorOf<[AnyFloat]>:$diagonals, + DenseI32ArrayAttr:$diagonal_indices, + Builtin_IntegerAttr:$level, + Builtin_IntegerAttr:$logBabyStepGiantStepRatio + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +def Cheddar_EvalPolyOp : Cheddar_Op<"eval_poly"> { + let summary = "Evaluate a polynomial on a ciphertext"; + let description = [{ + Evaluates a polynomial (e.g., Chebyshev approximation) on an encrypted + input using CHEDDAR's EvalPoly extension. + }]; + let arguments = (ins + Cheddar_Context:$ctx, + Cheddar_Ciphertext:$input, + Cheddar_EvkMap:$evkMap, + ArrayAttr:$coefficients, + Builtin_IntegerAttr:$level + ); + let results = (outs Cheddar_Ciphertext:$output); +} + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDAROPS_TD_ diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.cpp b/lib/Dialect/Cheddar/IR/CheddarTypes.cpp new file mode 100644 index 0000000000..707851f21a --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.cpp @@ -0,0 +1,7 @@ +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" + +namespace mlir { +namespace heir { +namespace cheddar {} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.h b/lib/Dialect/Cheddar/IR/CheddarTypes.h new file mode 100644 index 0000000000..dc4191c8dc --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.h @@ -0,0 +1,14 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_H_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_H_ + +// IWYU pragma: begin_keep +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/HEIRInterfaces.h" +#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project +// IWYU pragma: end_keep + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h.inc" + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_H_ diff --git a/lib/Dialect/Cheddar/IR/CheddarTypes.td b/lib/Dialect/Cheddar/IR/CheddarTypes.td new file mode 100644 index 0000000000..2663f3fa4d --- /dev/null +++ b/lib/Dialect/Cheddar/IR/CheddarTypes.td @@ -0,0 +1,116 @@ +#ifndef LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_TD_ +#define LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_TD_ + +include "CheddarDialect.td" +include "CheddarAttributes.td" + +include "lib/Dialect/HEIRInterfaces.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/OpAsmInterface.td" + +// A base class for all types in this dialect +class Cheddar_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; + + string asmName = ?; + string aliasSuffix = ""; + let extraClassDeclaration = [{ + // OpAsmTypeInterface method + void getAsmName(::mlir::OpAsmSetNameFn setNameFn) const { + setNameFn("}] # asmName # [{"); + } + + ::mlir::OpAsmDialectInterface::AliasResult getAlias(::llvm::raw_ostream &os) const { + os << "}] # asmName # [{"; + }] # aliasSuffix # [{ + return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias; + } + }]; +} + +// Context types + +def Cheddar_Context : Cheddar_Type<"Context", "context"> { + let description = [{ + This type represents a CHEDDAR Context (or BootContext), which is the + main server-side computation engine. Created via Context::Create(param) + or BootContext::Create(param, boot_param). + }]; + let asmName = "ctx"; +} + +def Cheddar_Parameter : Cheddar_Type<"Parameter", "parameter"> { + let description = [{ + This type represents a CHEDDAR Parameter object, constructed from a JSON + file or programmatically. + }]; + let asmName = "param"; +} + +def Cheddar_Encoder : Cheddar_Type<"Encoder", "encoder"> { + let description = [{ + This type represents the CHEDDAR Encoder, accessed via context->encoder_. + Used for encoding/decoding plaintext values. + }]; + let asmName = "encoder"; +} + +def Cheddar_UserInterface : Cheddar_Type<"UserInterface", "user_interface"> { + let description = [{ + This type represents the CHEDDAR UserInterface, used for key generation, + encryption, and decryption. Note: this is for test purposes only and is + not security-hardened. + }]; + let asmName = "ui"; +} + +// Data types + +def Cheddar_Ciphertext : Cheddar_Type<"Ciphertext", "ciphertext", [SecretTypeInterface]> { + let description = [{ + This type represents a CHEDDAR Ciphertext. Move-only, lives on GPU. + }]; + let asmName = "ct"; +} + +def Cheddar_Plaintext : Cheddar_Type<"Plaintext", "plaintext"> { + let description = [{ + This type represents a CHEDDAR Plaintext. Contains an NTT-applied + encoded message. + }]; + let asmName = "pt"; +} + +def Cheddar_Constant : Cheddar_Type<"Constant", "constant"> { + let description = [{ + This type represents a CHEDDAR Constant. A scalar in RNS form, + used for efficient ciphertext-scalar operations. + }]; + let asmName = "const"; +} + +// Key types + +def Cheddar_EvalKey : Cheddar_Type<"EvalKey", "eval_key"> { + let description = [{ + This type represents a single CHEDDAR EvaluationKey. + }]; + let asmName = "evk"; +} + +def Cheddar_EvkMap : Cheddar_Type<"EvkMap", "evk_map"> { + let description = [{ + This type represents a CHEDDAR EvkMap, which bundles all evaluation + keys (multiplication, rotation, conjugation, etc.) into a single map. + }]; + let asmName = "evk_map"; +} + +// Type aliases for op constraints +def Cheddar_CiphertextOrPlaintext : AnyTypeOf<[Cheddar_Ciphertext, Cheddar_Plaintext]>; +def Cheddar_CiphertextOrConstant : AnyTypeOf<[Cheddar_Ciphertext, Cheddar_Constant]>; + +#endif // LIB_DIALECT_CHEDDAR_IR_CHEDDARTYPES_TD_ diff --git a/lib/Dialect/Cheddar/Transforms/BUILD b/lib/Dialect/Cheddar/Transforms/BUILD new file mode 100644 index 0000000000..8923c9b40d --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/BUILD @@ -0,0 +1,48 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "FuseOps", + srcs = ["FuseOps.cpp"], + hdrs = ["FuseOps.h"], + deps = [ + ":fuse_ops_inc_gen", + "@heir//lib/Dialect/Cheddar/IR:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "ConfigureCryptoContext", + srcs = ["ConfigureCryptoContext.cpp"], + hdrs = ["ConfigureCryptoContext.h"], + deps = [ + ":configure_crypto_context_inc_gen", + "@heir//lib/Dialect/CKKS/IR:Dialect", + "@heir//lib/Dialect/Cheddar/IR:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +add_heir_transforms( + generated_target_name = "fuse_ops_inc_gen", + header_filename = "FuseOps.h.inc", + pass_name = "CheddarFuseOps", + td_file = "FuseOps.td", +) + +add_heir_transforms( + generated_target_name = "configure_crypto_context_inc_gen", + header_filename = "ConfigureCryptoContext.h.inc", + pass_name = "ConfigureCryptoContext", + td_file = "ConfigureCryptoContext.td", +) diff --git a/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.cpp b/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.cpp new file mode 100644 index 0000000000..469d33d8f0 --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.cpp @@ -0,0 +1,52 @@ +#include "lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h" + +#include "lib/Dialect/CKKS/IR/CKKSAttributes.h" +#include "lib/Dialect/CKKS/IR/CKKSDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarAttributes.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project + +namespace mlir::heir::cheddar { + +#define GEN_PASS_DEF_CONFIGURECRYPTOCONTEXT +#include "lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h.inc" + +struct ConfigureCryptoContext + : public impl::ConfigureCryptoContextBase { + using ConfigureCryptoContextBase::ConfigureCryptoContextBase; + + void runOnOperation() override { + auto moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + auto schemeParamAttr = moduleOp->getAttrOfType( + ckks::CKKSDialect::kSchemeParamAttrName); + + if (schemeParamAttr) { + int64_t logN = schemeParamAttr.getLogN(); + int64_t logDefaultScale = schemeParamAttr.getLogDefaultScale(); + + moduleOp->setAttr("cheddar.logN", + IntegerAttr::get(IntegerType::get(ctx, 64), logN)); + moduleOp->setAttr( + "cheddar.logDefaultScale", + IntegerAttr::get(IntegerType::get(ctx, 64), logDefaultScale)); + + if (auto Q = schemeParamAttr.getQ()) { + moduleOp->setAttr("cheddar.Q", Q); + } + + if (auto P = schemeParamAttr.getP()) { + moduleOp->setAttr("cheddar.P", P); + } + + moduleOp->removeAttr(ckks::CKKSDialect::kSchemeParamAttrName); + } + + moduleOp->removeAttr("scheme.ckks"); + } +}; + +} // namespace mlir::heir::cheddar diff --git a/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h b/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h new file mode 100644 index 0000000000..77e79f5c15 --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_CHEDDAR_TRANSFORMS_CONFIGURECRYPTOCONTEXT_H_ +#define LIB_DIALECT_CHEDDAR_TRANSFORMS_CONFIGURECRYPTOCONTEXT_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir::cheddar { + +#define GEN_PASS_DECL +#include "lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h.inc" + +} // namespace mlir::heir::cheddar + +#endif // LIB_DIALECT_CHEDDAR_TRANSFORMS_CONFIGURECRYPTOCONTEXT_H_ diff --git a/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.td b/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.td new file mode 100644 index 0000000000..c113d91401 --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.td @@ -0,0 +1,28 @@ +#ifndef LIB_DIALECT_CHEDDAR_TRANSFORMS_CONFIGURECRYPTOCONTEXT_TD_ +#define LIB_DIALECT_CHEDDAR_TRANSFORMS_CONFIGURECRYPTOCONTEXT_TD_ + +include "mlir/Pass/PassBase.td" + +def ConfigureCryptoContext : Pass<"cheddar-configure-crypto-context"> { + let summary = "Configure CHEDDAR crypto context from scheme parameters."; + + let description = [{ + This pass reads the `ckks.schemeParam` module attribute, converts its + fields (logN, logDefaultScale, Q, P) into CHEDDAR-specific module + attributes, and removes the CKKS scheme attributes. + + After this pass, the module no longer depends on the CKKS dialect. + }]; + + let dependentDialects = [ + "mlir::heir::cheddar::CheddarDialect", + ]; + + let options = [ + Option<"entryFunction", "entry-function", "std::string", + /*default=*/"", "Default entry function " + "name of entry function.">, + ]; +} + +#endif // LIB_DIALECT_CHEDDAR_TRANSFORMS_CONFIGURECRYPTOCONTEXT_TD_ diff --git a/lib/Dialect/Cheddar/Transforms/FuseOps.cpp b/lib/Dialect/Cheddar/Transforms/FuseOps.cpp new file mode 100644 index 0000000000..5bc8ab8c8c --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/FuseOps.cpp @@ -0,0 +1,191 @@ +#include "lib/Dialect/Cheddar/Transforms/FuseOps.h" + +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +namespace mlir::heir::cheddar { + +//===----------------------------------------------------------------------===// +// Fusion patterns +//===----------------------------------------------------------------------===// + +// Pattern: mult + relinearize + rescale -> hmult(rescale=true) +struct FuseMultRelinRescale : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RescaleOp rescaleOp, + PatternRewriter &rewriter) const override { + // Check that the input to rescale is a relinearize + auto relinOp = rescaleOp.getInput().getDefiningOp(); + if (!relinOp || !relinOp.getResult().hasOneUse()) return failure(); + + // Check that the input to relinearize is a mult + auto multOp = relinOp.getInput().getDefiningOp(); + if (!multOp || !multOp.getResult().hasOneUse()) return failure(); + + // Fuse into HMult with rescale=true + rewriter.replaceOpWithNewOp( + rescaleOp, rescaleOp.getOutput().getType(), multOp.getCtx(), + multOp.getLhs(), multOp.getRhs(), relinOp.getMultKey(), + /*rescale=*/rewriter.getBoolAttr(true)); + + // Clean up now-dead ops + rewriter.eraseOp(relinOp); + rewriter.eraseOp(multOp); + return success(); + } +}; + +// Pattern: mult + relinearize -> hmult(rescale=false) +struct FuseMultRelin : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RelinearizeOp relinOp, + PatternRewriter &rewriter) const override { + // Don't match if this relin feeds into a rescale (handled by + // FuseMultRelinRescale) + if (relinOp.getResult().hasOneUse()) { + auto *user = *relinOp.getResult().getUsers().begin(); + if (isa(user)) return failure(); + } + + auto multOp = relinOp.getInput().getDefiningOp(); + if (!multOp || !multOp.getResult().hasOneUse()) return failure(); + + rewriter.replaceOpWithNewOp( + relinOp, relinOp.getOutput().getType(), multOp.getCtx(), + multOp.getLhs(), multOp.getRhs(), relinOp.getMultKey(), + /*rescale=*/rewriter.getBoolAttr(false)); + + rewriter.eraseOp(multOp); + return success(); + } +}; + +// Pattern: mult + relinearize_rescale -> hmult(rescale=true) +// (In case we already have a fused relin+rescale but not the full triple) +struct FuseMultRelinRescaleFused + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RelinearizeRescaleOp relinRescaleOp, + PatternRewriter &rewriter) const override { + auto multOp = relinRescaleOp.getInput().getDefiningOp(); + if (!multOp || !multOp.getResult().hasOneUse()) return failure(); + + rewriter.replaceOpWithNewOp( + relinRescaleOp, relinRescaleOp.getOutput().getType(), multOp.getCtx(), + multOp.getLhs(), multOp.getRhs(), relinRescaleOp.getMultKey(), + /*rescale=*/rewriter.getBoolAttr(true)); + + rewriter.eraseOp(multOp); + return success(); + } +}; + +// Pattern: hrot(a) + b -> hrot_add(a, b) +// Matches: %rotated = cheddar.hrot ...; %sum = cheddar.add %ctx, %rotated, %b +struct FuseHRotAdd : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AddOp addOp, + PatternRewriter &rewriter) const override { + // Check if either operand is an hrot with a single use + HRotOp hrotOp = nullptr; + Value otherOperand; + + if (auto lhsHrot = addOp.getLhs().getDefiningOp()) { + if (lhsHrot.getResult().hasOneUse()) { + hrotOp = lhsHrot; + otherOperand = addOp.getRhs(); + } + } + if (!hrotOp) { + if (auto rhsHrot = addOp.getRhs().getDefiningOp()) { + if (rhsHrot.getResult().hasOneUse()) { + hrotOp = rhsHrot; + otherOperand = addOp.getLhs(); + } + } + } + if (!hrotOp) return failure(); + + // Only fuse static-shift rotations into HRotAdd + auto staticShift = hrotOp.getStaticShift(); + if (!staticShift) return failure(); + + rewriter.replaceOpWithNewOp( + addOp, addOp.getOutput().getType(), hrotOp.getCtx(), hrotOp.getInput(), + otherOperand, hrotOp.getRotKey(), *staticShift); + + rewriter.eraseOp(hrotOp); + return success(); + } +}; + +// Pattern: hconj(a) + b -> hconj_add(a, b) +struct FuseHConjAdd : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AddOp addOp, + PatternRewriter &rewriter) const override { + HConjOp hconjOp = nullptr; + Value otherOperand; + + if (auto lhsHconj = addOp.getLhs().getDefiningOp()) { + if (lhsHconj.getResult().hasOneUse()) { + hconjOp = lhsHconj; + otherOperand = addOp.getRhs(); + } + } + if (!hconjOp) { + if (auto rhsHconj = addOp.getRhs().getDefiningOp()) { + if (rhsHconj.getResult().hasOneUse()) { + hconjOp = rhsHconj; + otherOperand = addOp.getLhs(); + } + } + } + if (!hconjOp) return failure(); + + rewriter.replaceOpWithNewOp( + addOp, addOp.getOutput().getType(), hconjOp.getCtx(), + hconjOp.getInput(), otherOperand, hconjOp.getConjKey()); + + rewriter.eraseOp(hconjOp); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass definition +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_CHEDDARFUSEOPS +#include "lib/Dialect/Cheddar/Transforms/FuseOps.h.inc" + +struct CheddarFuseOps : public impl::CheddarFuseOpsBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + // Fusion patterns. Order matters: try the longest fusions first. + patterns.add(context, /*benefit=*/3); + patterns.add(context, /*benefit=*/2); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::heir::cheddar diff --git a/lib/Dialect/Cheddar/Transforms/FuseOps.h b/lib/Dialect/Cheddar/Transforms/FuseOps.h new file mode 100644 index 0000000000..74fcf2b307 --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/FuseOps.h @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_CHEDDAR_TRANSFORMS_FUSEOPS_H_ +#define LIB_DIALECT_CHEDDAR_TRANSFORMS_FUSEOPS_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir::cheddar { + +#define GEN_PASS_DECL +#include "lib/Dialect/Cheddar/Transforms/FuseOps.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/Cheddar/Transforms/FuseOps.h.inc" + +} // namespace mlir::heir::cheddar + +#endif // LIB_DIALECT_CHEDDAR_TRANSFORMS_FUSEOPS_H_ diff --git a/lib/Dialect/Cheddar/Transforms/FuseOps.td b/lib/Dialect/Cheddar/Transforms/FuseOps.td new file mode 100644 index 0000000000..ead026408a --- /dev/null +++ b/lib/Dialect/Cheddar/Transforms/FuseOps.td @@ -0,0 +1,27 @@ +#ifndef LIB_DIALECT_CHEDDAR_TRANSFORMS_FUSEOPS_TD_ +#define LIB_DIALECT_CHEDDAR_TRANSFORMS_FUSEOPS_TD_ + +include "mlir/Pass/PassBase.td" + +def CheddarFuseOps : Pass<"cheddar-fuse-ops"> { + let summary = "Fuse CHEDDAR ops into compound GPU kernel operations."; + + let description = [{ + This pass fuses sequences of CHEDDAR ops into compound ops that + map to single fused GPU kernels in the CHEDDAR library: + + - mult + relinearize + rescale -> hmult (with rescale=true) + - mult + relinearize -> hmult (with rescale=false) + - hrot + add -> hrot_add (fused rotation + addition) + - hconj + add -> hconj_add (fused conjugation + addition) + + These fused operations are significantly faster on GPU as they + avoid intermediate memory allocations and kernel launch overhead. + }]; + + let dependentDialects = [ + "mlir::heir::cheddar::CheddarDialect", + ]; +} + +#endif // LIB_DIALECT_CHEDDAR_TRANSFORMS_FUSEOPS_TD_ diff --git a/lib/Dialect/LWE/Conversions/LWEToCheddar/BUILD b/lib/Dialect/LWE/Conversions/LWEToCheddar/BUILD new file mode 100644 index 0000000000..326060dec4 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToCheddar/BUILD @@ -0,0 +1,38 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "LWEToCheddar", + srcs = ["LWEToCheddar.cpp"], + hdrs = [ + "LWEToCheddar.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect:ModuleAttributes", + "@heir//lib/Dialect/CKKS/IR:Dialect", + "@heir//lib/Dialect/Cheddar/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Utils", + "@heir//lib/Utils:ConversionUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + ], +) + +add_heir_transforms( + header_filename = "LWEToCheddar.h.inc", + pass_name = "LWEToCheddar", + td_file = "LWEToCheddar.td", +) diff --git a/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.cpp b/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.cpp new file mode 100644 index 0000000000..2b4705b956 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.cpp @@ -0,0 +1,868 @@ +#include "lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h" + +#include +#include +#include +#include +#include + +#include "lib/Dialect/CKKS/IR/CKKSAttributes.h" +#include "lib/Dialect/CKKS/IR/CKKSDialect.h" +#include "lib/Dialect/CKKS/IR/CKKSOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Dialect/LWE/IR/LWEAttributes.h" +#include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Dialect/ModuleAttributes.h" +#include "lib/Utils/ConversionUtils.h" +#include "lib/Utils/Utils.h" +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project + +#define DEBUG_TYPE "lwe-to-cheddar" + +namespace mlir::heir::lwe { + +static FailureOr getCKKSLogDefaultScale(Operation* op) { + auto moduleOp = op->getParentOfType(); + if (!moduleOp) { + return failure(); + } + auto schemeParamAttr = moduleOp->getAttrOfType( + ckks::CKKSDialect::kSchemeParamAttrName); + if (!schemeParamAttr) { + return failure(); + } + return schemeParamAttr.getLogDefaultScale(); +} + +enum class CheddarLevelReduceBucketKind { + kCanonical, + kScaledOnce, +}; + +static FailureOr getCheddarLevelReduceBucketKind( + Operation* op, lwe::LWECiphertextType ctType) { + auto logDefaultScale = getCKKSLogDefaultScale(op); + if (failed(logDefaultScale)) { + return failure(); + } + + // Scaling factors on CKKS encodings are stored log-additively on main (a + // multiply adds the log-scales). The "canonical" scale after the initial + // encoding is `logDefaultScale`; one post-multiply-without-rescale bucket + // is `2 * logDefaultScale`. + int64_t logScale = getScalingFactorFromEncodingAttr( + ctType.getPlaintextSpace().getEncoding()); + + if (logScale == *logDefaultScale) { + return CheddarLevelReduceBucketKind::kCanonical; + } + if (logScale == 2 * *logDefaultScale) { + return CheddarLevelReduceBucketKind::kScaledOnce; + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// Type converter +//===----------------------------------------------------------------------===// + +class ToCheddarTypeConverter : public TypeConverter { + public: + ToCheddarTypeConverter(MLIRContext* ctx) { + addConversion([](Type type) { return type; }); + addConversion([ctx](lwe::LWECiphertextType type) -> Type { + return cheddar::CiphertextType::get(ctx); + }); + addConversion([ctx](lwe::LWEPlaintextType type) -> Type { + return cheddar::PlaintextType::get(ctx); + }); + addConversion([ctx](lwe::LWEPublicKeyType type) -> Type { + LLVM_DEBUG(llvm::dbgs() + << "Converting LWEPublicKeyType -> UserInterfaceType\n"); + return cheddar::UserInterfaceType::get(ctx); + }); + addConversion([ctx](lwe::LWESecretKeyType type) -> Type { + LLVM_DEBUG(llvm::dbgs() + << "Converting LWESecretKeyType -> UserInterfaceType\n"); + return cheddar::UserInterfaceType::get(ctx); + }); + addConversion([this](RankedTensorType type) -> Type { + return RankedTensorType::get(type.getShape(), + this->convertType(type.getElementType())); + }); + } +}; + +//===----------------------------------------------------------------------===// +// Helper: get contextual arguments by type +//===----------------------------------------------------------------------===// + +namespace { + +template +bool containsArgumentOfDialect(Operation* op) { + auto funcOp = dyn_cast(op); + if (!funcOp) { + return false; + } + return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) { + return DialectEqual()( + &getElementTypeOrSelf(argType).getDialect()); + }); +} + +template +FailureOr getContextualArg(Operation* op) { + auto result = getContextualArgFromFunc(op); + if (failed(result)) { + return op->emitOpError() + << "Found op in a function without a required CHEDDAR context " + "argument. Did the AddEvaluatorArg pattern fail to run?"; + } + return result.value(); +} + +FailureOr getContextualArg(Operation* op, Type type) { + return getContextualArgFromFunc(op, type); +} + +SmallVector getRequiredCheddarContextTypes( + Operation* op, + const std::vector>& evaluators) { + SmallVector requiredTypes; + for (const auto& evaluator : evaluators) { + if (evaluator.second(op)) { + requiredTypes.push_back(evaluator.first); + } + } + return requiredTypes; +} + +bool hasLeadingTypes(TypeRange actualTypes, ArrayRef requiredTypes) { + if (actualTypes.size() < requiredTypes.size()) { + return false; + } + return llvm::equal(requiredTypes, + actualTypes.take_front(requiredTypes.size())); +} + +void addRequiredCheddarContextArgs( + ModuleOp module, + const std::vector>& evaluators) { + module.walk([&](func::FuncOp funcOp) { + SmallVector requiredTypes = + getRequiredCheddarContextTypes(funcOp, evaluators); + if (requiredTypes.empty() || + hasLeadingTypes(funcOp.getArgumentTypes(), requiredTypes)) { + return; + } + + SmallVector argIndices(requiredTypes.size(), 0); + SmallVector argAttrs(requiredTypes.size(), nullptr); + SmallVector argLocs(requiredTypes.size(), funcOp.getLoc()); + (void)funcOp.insertArguments(argIndices, requiredTypes, argAttrs, argLocs); + }); +} + +LogicalResult addRequiredCheddarContextOperandsToCalls( + ModuleOp module, + const std::vector>& evaluators) { + LogicalResult result = success(); + module.walk([&](func::CallOp callOp) { + if (failed(result)) { + return WalkResult::interrupt(); + } + + auto callee = getCalledFunction(callOp); + if (failed(callee)) { + result = callOp.emitOpError("could not find callee function"); + return WalkResult::interrupt(); + } + + SmallVector requiredTypes = + getRequiredCheddarContextTypes(callee.value(), evaluators); + if (requiredTypes.empty() || + hasLeadingTypes(callOp.getOperandTypes(), requiredTypes)) { + return WalkResult::advance(); + } + + SmallVector contextOperands; + for (Type requiredType : requiredTypes) { + auto contextOperand = + getContextualArg(callOp.getOperation(), requiredType); + if (failed(contextOperand)) { + result = failure(); + return WalkResult::interrupt(); + } + contextOperands.push_back(contextOperand.value()); + } + + callOp->insertOperands(0, contextOperands); + return WalkResult::advance(); + }); + return result; +} + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +// Binary ct-ct operations: ckks.add -> cheddar.add, etc. +template +struct ConvertCKKSBinOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + CKKSOp op, typename CKKSOp::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getLhs(), adaptor.getRhs()); + return success(); + } +}; + +// LWE R* ops (used by torch-linalg-to-ckks pipeline) +using ConvertRAddOp = ConvertCKKSBinOp; +using ConvertRSubOp = ConvertCKKSBinOp; +using ConvertRMulOp = ConvertCKKSBinOp; + +// Ct-pt operations +template +struct ConvertCKKSPlainOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + CKKSOp op, typename CKKSOp::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + + // Ensure ciphertext is first operand (CHEDDAR convention) + Value ciphertext = adaptor.getLhs(); + Value plaintext = adaptor.getRhs(); + if (!isa(adaptor.getLhs().getType())) { + ciphertext = adaptor.getRhs(); + plaintext = adaptor.getLhs(); + } + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), ciphertext, plaintext); + return success(); + } +}; + +using ConvertRAddPlainOp = + ConvertCKKSPlainOp; +using ConvertRMulPlainOp = + ConvertCKKSPlainOp; + +// SubPlain needs special handling: when plaintext is LHS (pt - ct), +// we emit neg(ct) + pt instead of sub_plain with swapped operands. +template +struct ConvertSubPlainOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubPlainOp op, typename SubPlainOp::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + + auto outType = this->typeConverter->convertType(op.getOutput().getType()); + + if (isa(adaptor.getLhs().getType())) { + // ct - pt: use sub_plain directly + rewriter.replaceOpWithNewOp( + op, outType, ctx.value(), adaptor.getLhs(), adaptor.getRhs()); + } else { + // pt - ct: emit neg(ct) + add_plain(neg_ct, pt) + auto negated = cheddar::NegOp::create(rewriter, op.getLoc(), outType, + ctx.value(), adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + op, outType, ctx.value(), negated, adaptor.getLhs()); + } + return success(); + } +}; + +using ConvertRSubPlainOp = ConvertSubPlainOp; + +// Negate +struct ConvertCKKSNegateOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ckks::NegateOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getInput()); + return success(); + } +}; + +// Relinearize — needs the multiplication key +struct ConvertCKKSRelinOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ckks::RelinearizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + auto ui = getContextualArg(op.getOperation()); + if (failed(ui)) return ui; + + // Get the multiplication key from the UI + auto multKey = cheddar::GetMultKeyOp::create( + rewriter, op.getLoc(), cheddar::EvalKeyType::get(getContext()), + ui.value()); + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getInput(), multKey); + return success(); + } +}; + +// Rescale (mod reduce in CKKS) +struct ConvertCKKSRescaleOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ckks::RescaleOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getInput()); + return success(); + } +}; + +// Rotate +struct ConvertCKKSRotateOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ckks::RotateOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + auto ui = getContextualArg(op.getOperation()); + if (failed(ui)) return ui; + + Value dynamicShift = adaptor.getDynamicShift(); + IntegerAttr staticShift = op.getStaticShiftAttr(); + if (!staticShift && !dynamicShift) { + return rewriter.notifyMatchFailure( + op, "rotate op must have either static or dynamic shift"); + } + + if (dynamicShift) { + // Dynamic shift: the emitter will inline ui.GetRotationKey(shift). + // Create a placeholder GetRotKeyOp that the emitter traces back to + // the UserInterface for the key lookup. + auto rotKey = cheddar::GetRotKeyOp::create( + rewriter, op.getLoc(), cheddar::EvalKeyType::get(getContext()), + ui.value(), + rewriter.getI64IntegerAttr( + cheddar::kDynamicRotationKeyDistanceSentinel)); + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getInput(), rotKey, dynamicShift, + /*static_shift=*/nullptr); + } else { + // Static shift: get the rotation key at lowering time. + auto rotKey = cheddar::GetRotKeyOp::create( + rewriter, op.getLoc(), cheddar::EvalKeyType::get(getContext()), + ui.value(), staticShift); + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getInput(), rotKey, /*dynamic_shift=*/Value(), + staticShift); + } + return success(); + } +}; + +// Level reduce +struct ConvertCKKSLevelReduceOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ckks::LevelReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + + // Derive target level from the output ciphertext's modulus chain. + auto outputCtType = dyn_cast( + getElementTypeOrSelf(op.getOutput().getType())); + auto inputCtType = + dyn_cast(getElementTypeOrSelf(op.getInput())); + int64_t targetLevelVal = + outputCtType ? outputCtType.getModulusChain().getCurrent() : 0; + + if (!outputCtType || !inputCtType) { + return op.emitOpError() << "expected ciphertext input and output types"; + } + + auto bucketKind = + getCheddarLevelReduceBucketKind(op.getOperation(), outputCtType); + if (failed(bucketKind)) { + return op.emitOpError() + << "unsupported CHEDDAR level_reduce scaling factor"; + } + + auto cheddarCtType = + this->typeConverter->convertType(op.getOutput().getType()); + if (*bucketKind == CheddarLevelReduceBucketKind::kCanonical) { + auto targetLevel = rewriter.getI64IntegerAttr(targetLevelVal); + rewriter.replaceOpWithNewOp( + op, cheddarCtType, ctx.value(), adaptor.getInput(), targetLevel); + return success(); + } + + auto encoder = getContextualArg(op.getOperation()); + if (failed(encoder)) return encoder; + + auto logDefaultScale = getCKKSLogDefaultScale(op.getOperation()); + if (failed(logDefaultScale)) { + return op.emitOpError() << "missing CKKS scheme parameter"; + } + + int64_t currentLevel = inputCtType.getModulusChain().getCurrent(); + if (targetLevelVal < 0 || targetLevelVal > currentLevel) { + return op.emitOpError() + << "cannot level_reduce from level " << currentLevel + << " to incompatible target level " << targetLevelVal; + } + Value current = adaptor.getInput(); + auto one = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getF64FloatAttr(1.0)); + while (currentLevel > targetLevelVal) { + current = cheddar::RescaleOp::create(rewriter, op.getLoc(), cheddarCtType, + ctx.value(), current); + --currentLevel; + auto encodedOne = cheddar::EncodeConstantOp::create( + rewriter, op.getLoc(), cheddar::ConstantType::get(getContext()), + encoder.value(), one, rewriter.getI64IntegerAttr(currentLevel), + rewriter.getI64IntegerAttr(*logDefaultScale)); + current = + cheddar::MultConstOp::create(rewriter, op.getLoc(), cheddarCtType, + ctx.value(), current, encodedOne); + } + + rewriter.replaceOp(op, current); + return success(); + } +}; + +// Bootstrap +struct ConvertCKKSBootstrapOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ckks::BootstrapOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ctx = getContextualArg(op.getOperation()); + if (failed(ctx)) return ctx; + auto ui = getContextualArg(op.getOperation()); + if (failed(ui)) return ui; + + auto evkMap = cheddar::GetEvkMapOp::create( + rewriter, op.getLoc(), cheddar::EvkMapType::get(getContext()), + ui.value()); + + rewriter.replaceOpWithNewOp( + op, this->typeConverter->convertType(op.getOutput().getType()), + ctx.value(), adaptor.getInput(), evkMap); + return success(); + } +}; + +// Encode +struct ConvertLWEEncodeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::RLWEEncodeOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto encoder = getContextualArg(op.getOperation()); + if (failed(encoder)) return encoder; + + auto invEncoding = + dyn_cast(op.getEncoding()); + if (!invEncoding) { + return op.emitOpError() + << "requires inverse-canonical CKKS plaintext encoding for " + "CHEDDAR lowering"; + } + // In HEIR's main LWE encoding, the scaling factor stored on + // inverse_canonical_encoding is already log2-of-scale (additive for CKKS + // multiplies), so it maps directly onto CHEDDAR's logScale field. + int64_t logScale = invEncoding.getScalingFactor(); + + // lwe.rlwe_encode doesn't carry a plaintext level on main, so encode at + // the top of the modulus chain. Cross-level plaintexts (e.g. a mask used + // before and after a rescale, or mult_plain with a ciphertext that has + // already been rescaled) need proper level management and are not + // supported yet — those tests should be kept out of CI until MGMT lands. + int64_t level = 0; + // When the chain is longer than the circuit needs (heir.level_offset > 0), + // shift the level so CHEDDAR encodes plaintexts at the correct prime set. + if (auto levelOffset = + op->getParentOfType()->getAttrOfType( + "heir.level_offset")) { + level += levelOffset.getInt(); + } + auto ptTy = cheddar::PlaintextType::get(getContext()); + rewriter.replaceOpWithNewOp( + op, ptTy, encoder.value(), adaptor.getInput(), + rewriter.getI64IntegerAttr(level), + rewriter.getI64IntegerAttr(logScale)); + return success(); + } +}; + +// Decrypt +struct ConvertLWEDecryptOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::RLWEDecryptOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ui = getContextualArg(op.getOperation()); + if (failed(ui)) return ui; + + rewriter.replaceOpWithNewOp( + op, cheddar::PlaintextType::get(getContext()), ui.value(), + adaptor.getInput()); + return success(); + } +}; + +// Encrypt +struct ConvertLWEEncryptOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::RLWEEncryptOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto ui = getContextualArg(op.getOperation()); + if (failed(ui)) return ui; + + rewriter.replaceOpWithNewOp( + op, cheddar::CiphertextType::get(getContext()), ui.value(), + adaptor.getInput()); + return success(); + } +}; + +// Decode +struct ConvertLWEDecodeOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::RLWEDecodeOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto encoder = getContextualArg(op.getOperation()); + if (failed(encoder)) return encoder; + + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), encoder.value(), adaptor.getInput()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// AddEvaluatorArg pattern +//===----------------------------------------------------------------------===// + +struct AddCheddarContextArg : public OpConversionPattern { + AddCheddarContextArg( + mlir::MLIRContext* context, + const std::vector>& evaluators) + : OpConversionPattern(context, /* benefit= */ 2), + evaluators(evaluators) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + LLVM_DEBUG(llvm::dbgs() + << "AddCheddarContextArg for func " << op.getName() << "\n"); + SmallVector selectedTypes = + getRequiredCheddarContextTypes(op, evaluators); + + if (selectedTypes.empty()) { + return rewriter.notifyMatchFailure(op, "no CHEDDAR context needed"); + } + if (hasLeadingTypes(op.getArgumentTypes(), selectedTypes)) { + return rewriter.notifyMatchFailure( + op, "CHEDDAR context arguments already present"); + } + + for (Type selectedType : selectedTypes) { + LLVM_DEBUG(llvm::dbgs() + << " Adding context arg of type: " << selectedType << "\n"); + } + + SmallVector argIndices(selectedTypes.size(), 0); + SmallVector argAttrs(selectedTypes.size(), nullptr); + SmallVector argLocs(selectedTypes.size(), op.getLoc()); + + rewriter.modifyOpInPlace(op, [&] { + SmallVector indices(selectedTypes.size(), 0); + (void)op.insertArguments(indices, selectedTypes, argAttrs, argLocs); + }); + return success(); + } + + private: + std::vector> evaluators; +}; + +struct ConvertCheddarFuncCallOp : public OpConversionPattern { + ConvertCheddarFuncCallOp( + const mlir::TypeConverter& typeConverter, mlir::MLIRContext* context, + const std::vector>& evaluators) + : OpConversionPattern(typeConverter, context, + /*benefit=*/2), + evaluators(evaluators) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::CallOp op, typename func::CallOp::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto funcOp = getCalledFunction(op); + if (failed(funcOp)) { + return rewriter.notifyMatchFailure(op, "could not find callee function"); + } + + SmallVector requiredTypes = + getRequiredCheddarContextTypes(funcOp.value(), evaluators); + if (hasLeadingTypes(op.getOperandTypes(), requiredTypes) && + this->typeConverter->isLegal(op)) { + return rewriter.notifyMatchFailure( + op, "call already has required CHEDDAR signature"); + } + + SmallVector selectedValues; + for (Type requiredType : requiredTypes) { + auto result = getContextualArg(op.getOperation(), requiredType); + if (failed(result)) { + return rewriter.notifyMatchFailure(op, + "missing required CHEDDAR context"); + } + selectedValues.push_back(result.value()); + } + + SmallVector newOperands; + for (auto v : selectedValues) newOperands.push_back(v); + for (auto operand : adaptor.getOperands()) newOperands.push_back(operand); + + SmallVector convertedResultTypes; + if (failed(this->typeConverter->convertTypes(op.getResultTypes(), + convertedResultTypes))) { + return rewriter.notifyMatchFailure(op, + "could not convert call result types"); + } + + SmallVector dialectAttrs(op->getDialectAttrs()); + rewriter + .replaceOpWithNewOp(op, op.getCallee(), + convertedResultTypes, newOperands) + ->setDialectAttrs(dialectAttrs); + return success(); + } + + private: + std::vector> evaluators; +}; + +} // anonymous namespace + +//===----------------------------------------------------------------------===// +// Pass implementation +//===----------------------------------------------------------------------===// + +#define GEN_PASS_DEF_LWETOCHEDDAR +#include "lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h.inc" + +struct LWEToCheddar : public impl::LWEToCheddarBase { + // Workaround: dialect conversion may drop dialect attrs on func::CallOp. + // Save before conversion and restore after. + SmallVector> funcCallOpDialectAttrs; + + void saveFuncCallOpDialectAttrs() { + funcCallOpDialectAttrs.clear(); + auto* module = getOperation(); + module->walk([&](func::CallOp callOp) { + SmallVector dialectAttrs; + for (auto namedAttr : callOp->getDialectAttrs()) { + dialectAttrs.push_back(namedAttr); + } + funcCallOpDialectAttrs.push_back(dialectAttrs); + }); + } + + void restoreFuncCallOpDialectAttrs() { + auto* module = getOperation(); + auto* iter = funcCallOpDialectAttrs.begin(); + module->walk([&](func::CallOp callOp) { + callOp->setDialectAttrs(*iter); + ++iter; + }); + } + + void runOnOperation() override { + saveFuncCallOpDialectAttrs(); + + MLIRContext* context = &getContext(); + auto* module = getOperation(); + ToCheddarTypeConverter typeConverter(context); + + // Only run for CKKS modules (CHEDDAR is CKKS-only) + if (!moduleIsCKKS(module)) { + module->emitOpError("CHEDDAR backend only supports CKKS scheme"); + return signalPassFailure(); + } + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + target + .addIllegalOp(); + + RewritePatternSet patterns(context); + addStructuralConversionPatterns(typeConverter, patterns, target); + addTensorConversionPatterns(typeConverter, patterns, target); + + // Predicate: function contains CKKS/LWE ops or operands. This must remain + // true while dialect conversion is in flight, even after the generic func + // signature conversion has already rewritten LWE-typed arguments. + auto hasCryptoOps = [&](Operation* op) -> bool { + return containsArgumentOfDialect( + op) || + containsDialects(op); + }; + + // Predicate: function contains encode ops. + // Note: unlike Lattigo's containsEncode which uses walkFuncAndCallees to + // transitively walk through call sites, we use a local walk here. This is + // sufficient because AddCheddarContextArg processes each func independently + // and ConvertCheddarFuncCallOp threads context args from caller to callee. + // A caller that invokes a callee with encode ops will already have crypto + // args (since encode ops accompany crypto ops in practice), so the Encoder + // arg is threaded transitively via the hasCryptoOps predicate. + auto hasEncodeOps = [&](Operation* op) -> bool { + auto funcOp = dyn_cast(op); + if (!funcOp) return false; + bool found = false; + funcOp->walk([&](lwe::RLWEEncodeOp) { found = true; }); + return found; + }; + + // CHEDDAR context args to thread through functions + std::vector> evaluators = { + {cheddar::ContextType::get(context), hasCryptoOps}, + {cheddar::EncoderType::get(context), + [&](Operation* op) { return hasCryptoOps(op) || hasEncodeOps(op); }}, + {cheddar::UserInterfaceType::get(context), hasCryptoOps}, + }; + + // CKKS ops (scheme-specific ops that CKKSToLWE leaves in place) + patterns.add(typeConverter, context); + + // LWE R* ops (produced by torch-linalg-to-ckks pipeline) + patterns.add( + typeConverter, context); + + // LWE encrypt/decrypt/encode/decode ops + patterns.add(typeConverter, context); + + addRequiredCheddarContextArgs(cast(getOperation()), evaluators); + if (failed(addRequiredCheddarContextOperandsToCalls( + cast(getOperation()), evaluators))) { + return signalPassFailure(); + } + + // Dynamically legal: func ops that have been converted + target.addDynamicallyLegalOp([&](func::FuncOp op) { + SmallVector requiredTypes = + getRequiredCheddarContextTypes(op, evaluators); + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()) && + hasLeadingTypes(op.getArgumentTypes(), requiredTypes); + }); + + target.addDynamicallyLegalOp([&](func::CallOp op) { + FailureOr callee = getCalledFunction(op); + if (failed(callee)) { + return false; + } + SmallVector requiredTypes = + getRequiredCheddarContextTypes(callee.value(), evaluators); + return typeConverter.isLegal(op) && + hasLeadingTypes(op.getCalleeType().getInputs(), requiredTypes); + }); + + target.markUnknownOpDynamicallyLegal( + [&](Operation* op) -> std::optional { + return typeConverter.isLegal(op); + }); + + ConversionConfig config; + config.allowPatternRollback = false; + if (failed(applyPartialConversion(module, target, std::move(patterns), + config))) { + return signalPassFailure(); + } + + restoreFuncCallOpDialectAttrs(); + } +}; + +} // namespace mlir::heir::lwe + +// Include the generated pass definition +// (must be after the struct definition for the base class to find it) diff --git a/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h b/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h new file mode 100644 index 0000000000..8c03e8a6ea --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_LWE_CONVERSIONS_LWETOCHEDDAR_LWETOCHEDDAR_H_ +#define LIB_DIALECT_LWE_CONVERSIONS_LWETOCHEDDAR_LWETOCHEDDAR_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir::lwe { + +#define GEN_PASS_DECL +#include "lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h.inc" + +} // namespace mlir::heir::lwe + +#endif // LIB_DIALECT_LWE_CONVERSIONS_LWETOCHEDDAR_LWETOCHEDDAR_H_ diff --git a/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.td b/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.td new file mode 100644 index 0000000000..f963e8b060 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.td @@ -0,0 +1,26 @@ +#ifndef LIB_DIALECT_LWE_CONVERSIONS_LWETOCHEDDAR_LWETOCHEDDAR_TD_ +#define LIB_DIALECT_LWE_CONVERSIONS_LWETOCHEDDAR_LWETOCHEDDAR_TD_ + +include "mlir/Pass/PassBase.td" + +def LWEToCheddar : Pass<"lwe-to-cheddar"> { + let summary = "Lower `lwe`/`ckks` to `cheddar` dialect."; + + let description = [{ + This pass lowers the `lwe` and `ckks` dialects to the `cheddar` dialect, + targeting the CHEDDAR GPU FHE library. It also converts `orion` dialect + ops (linear_transform, chebyshev) to their CHEDDAR equivalents. + + The pass threads a context, encoder, and evk_map through function arguments + following the same pattern as the LWE-to-Lattigo pass. + }]; + + let dependentDialects = [ + "mlir::heir::cheddar::CheddarDialect", + "mlir::heir::ckks::CKKSDialect", + "mlir::heir::lwe::LWEDialect", + "mlir::tensor::TensorDialect", + ]; +} + +#endif // LIB_DIALECT_LWE_CONVERSIONS_LWETOCHEDDAR_LWETOCHEDDAR_TD_ diff --git a/lib/Dialect/ModuleAttributes.cpp b/lib/Dialect/ModuleAttributes.cpp index 88cf3802a9..f0dd6768b7 100644 --- a/lib/Dialect/ModuleAttributes.cpp +++ b/lib/Dialect/ModuleAttributes.cpp @@ -101,9 +101,15 @@ bool moduleIsLattigo(Operation* moduleOp) { nullptr; } +bool moduleIsCheddar(Operation* moduleOp) { + return moduleOp->getAttrOfType(kCheddarBackendAttrName) != + nullptr; +} + void moduleClearBackend(Operation* moduleOp) { moduleOp->removeAttr(kOpenfheBackendAttrName); moduleOp->removeAttr(kLattigoBackendAttrName); + moduleOp->removeAttr(kCheddarBackendAttrName); } void moduleSetOpenfhe(Operation* moduleOp) { @@ -118,5 +124,11 @@ void moduleSetLattigo(Operation* moduleOp) { mlir::UnitAttr::get(moduleOp->getContext())); } +void moduleSetCheddar(Operation* moduleOp) { + moduleClearBackend(moduleOp); + moduleOp->setAttr(kCheddarBackendAttrName, + mlir::UnitAttr::get(moduleOp->getContext())); +} + } // namespace heir } // namespace mlir diff --git a/lib/Dialect/ModuleAttributes.h b/lib/Dialect/ModuleAttributes.h index 16b064aa7c..e27bb9b436 100644 --- a/lib/Dialect/ModuleAttributes.h +++ b/lib/Dialect/ModuleAttributes.h @@ -54,14 +54,18 @@ constexpr const static ::llvm::StringLiteral kOpenfheBackendAttrName = "backend.openfhe"; constexpr const static ::llvm::StringLiteral kLattigoBackendAttrName = "backend.lattigo"; +constexpr const static ::llvm::StringLiteral kCheddarBackendAttrName = + "backend.cheddar"; bool moduleIsOpenfhe(Operation* moduleOp); bool moduleIsLattigo(Operation* moduleOp); +bool moduleIsCheddar(Operation* moduleOp); void moduleClearBackend(Operation* moduleOp); void moduleSetOpenfhe(Operation* moduleOp); void moduleSetLattigo(Operation* moduleOp); +void moduleSetCheddar(Operation* moduleOp); // Func attributes for client helpers // diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index 6ec1054925..b038518765 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -5,8 +5,11 @@ #include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h" #include "lib/Dialect/CKKS/Transforms/CKKSToLWE.h" +#include "lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h" +#include "lib/Dialect/Cheddar/Transforms/FuseOps.h" #include "lib/Dialect/Debug/Transforms/Passes.h" #include "lib/Dialect/Debug/Transforms/ValidateNames.h" +#include "lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h" #include "lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" #include "lib/Dialect/LWE/Transforms/AddDebugPort.h" @@ -525,6 +528,39 @@ BackendPipelineBuilder toLattigoPipelineBuilder() { }; } +BackendPipelineBuilder toCheddarPipelineBuilder() { + return [=](OpPassManager& pm, const BackendOptions& options) { + // Convert CKKS to LWE + pm.addPass(ckks::createCKKSToLWE()); + + if (options.debug) { + llvm::errs() << "warning: CHEDDAR backend currently ignores " + "--insert-debug-handler-calls\n"; + } + + // Convert LWE to CHEDDAR + pm.addPass(lwe::createLWEToCheddar()); + + if (options.fuseOps) pm.addPass(cheddar::createCheddarFuseOps()); + + // Simplify, in case the lowering revealed redundancy + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + + // Configure crypto context (consume CKKS module attrs) + auto configureCryptoContextOptions = + cheddar::ConfigureCryptoContextOptions{}; + configureCryptoContextOptions.entryFunction = options.entryFunction; + pm.addPass( + cheddar::createConfigureCryptoContext(configureCryptoContextOptions)); + + pm.addPass(createRemoveUnusedPureCall()); + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createSymbolDCEPass()); + }; +} + void linalgPreprocessingBuilder(OpPassManager& manager) { manager.addPass(createInlineActivations()); manager.addPass(createActivationCanonicalizations()); diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index e60b1eaeac..300dbf83e9 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -134,6 +134,11 @@ struct BackendOptions : public PassPipelineOptions { llvm::cl::desc("Insert function calls to an externally-defined debug " "function (cf. --lwe-add-debug-port)"), llvm::cl::init(false)}; + PassOptions::Option fuseOps{ + *this, "fuse-ops", + llvm::cl::desc("Fuse sequences of ops into compound GPU kernels " + "(CHEDDAR backend only)"), + llvm::cl::init(false)}; }; using RLWEPipelineBuilder = @@ -158,6 +163,8 @@ BackendPipelineBuilder toOpenFhePipelineBuilder(); BackendPipelineBuilder toLattigoPipelineBuilder(); +BackendPipelineBuilder toCheddarPipelineBuilder(); + // A subpipeline that preprocesses linalg ops to make them more suitable for // FHE. void linalgPreprocessingBuilder(OpPassManager& manager); diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index 1695ea6561..4e3efb1181 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -101,8 +101,11 @@ cc_library( ":PipelineRegistration", "@heir//lib/Dialect/BGV/Conversions/BGVToLWE", "@heir//lib/Dialect/CKKS/Transforms:CKKSToLWE", + "@heir//lib/Dialect/Cheddar/Transforms:ConfigureCryptoContext", + "@heir//lib/Dialect/Cheddar/Transforms:FuseOps", "@heir//lib/Dialect/Debug/Transforms", "@heir//lib/Dialect/Debug/Transforms:ValidateNames", + "@heir//lib/Dialect/LWE/Conversions/LWEToCheddar", "@heir//lib/Dialect/LWE/Conversions/LWEToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", diff --git a/lib/Target/Cheddar/BUILD b/lib/Target/Cheddar/BUILD new file mode 100644 index 0000000000..dc5d2285fa --- /dev/null +++ b/lib/Target/Cheddar/BUILD @@ -0,0 +1,30 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "CheddarEmitter", + srcs = ["CheddarEmitter.cpp"], + hdrs = [ + "CheddarEmitter.h", + "CheddarTemplates.h", + ], + deps = [ + "@heir//lib/Analysis/SelectVariableNames", + "@heir//lib/Dialect/Cheddar/IR:Dialect", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Utils:TargetUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + ], +) diff --git a/lib/Target/Cheddar/CheddarEmitter.cpp b/lib/Target/Cheddar/CheddarEmitter.cpp new file mode 100644 index 0000000000..0a9252d91c --- /dev/null +++ b/lib/Target/Cheddar/CheddarEmitter.cpp @@ -0,0 +1,1781 @@ +#include "lib/Target/Cheddar/CheddarEmitter.h" + +#include +#include + +#include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "lib/Dialect/Cheddar/IR/CheddarTypes.h" +#include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "lib/Target/Cheddar/CheddarTemplates.h" +#include "lib/Utils/RotationUtils.h" +#include "lib/Utils/TargetUtils.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/ManagedStatic.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace cheddar { + +namespace { + +bool isFunctionEntryBlockArg(Value value) { + auto blockArg = dyn_cast(value); + if (!blockArg) return false; + Block *owner = blockArg.getOwner(); + return owner->isEntryBlock() && isa(owner->getParentOp()); +} + +Operation *getAncestorInBlock(Operation *op, Block *block) { + while (op && op->getBlock() != block) op = op->getParentOp(); + return op; +} + +bool hasLaterUsesInSameScope(Value value, Operation *beforeOp) { + Block *block = beforeOp->getBlock(); + for (OpOperand &use : value.getUses()) { + Operation *user = getAncestorInBlock(use.getOwner(), block); + if (!user || user == beforeOp) continue; + if (beforeOp->isBeforeInBlock(user)) return true; + } + return false; +} + +bool canMoveTensorDestIntoResult(Value dest, Operation *consumer) { + if (isFunctionEntryBlockArg(dest)) return false; + return !hasLaterUsesInSameScope(dest, consumer); +} + +bool isContextCopyableMoveOnlyType(Type type) { + return isa(type); +} + +bool canMoveValueIntoConsumer(Value value) { + if (isFunctionEntryBlockArg(value)) return false; + return value.hasOneUse(); +} + +std::string sanitizeCppIdentifier(StringRef name) { + static constexpr llvm::StringLiteral keywords[] = { + "alignas", "alignof", "and", + "and_eq", "asm", "auto", + "bitand", "bitor", "bool", + "break", "case", "catch", + "char", "char8_t", "char16_t", + "char32_t", "class", "compl", + "concept", "const", "consteval", + "constexpr", "constinit", "const_cast", + "continue", "co_await", "co_return", + "co_yield", "decltype", "default", + "delete", "do", "double", + "dynamic_cast", "else", "enum", + "explicit", "export", "extern", + "false", "float", "for", + "friend", "goto", "if", + "inline", "int", "long", + "mutable", "namespace", "new", + "noexcept", "not", "not_eq", + "nullptr", "operator", "or", + "or_eq", "private", "protected", + "public", "register", "reinterpret_cast", + "requires", "return", "short", + "signed", "sizeof", "static", + "static_assert", "static_cast", "struct", + "switch", "template", "this", + "thread_local", "throw", "true", + "try", "typedef", "typeid", + "typename", "union", "unsigned", + "using", "virtual", "void", + "volatile", "wchar_t", "while", + "xor", "xor_eq"}; + if (llvm::is_contained(keywords, name)) { + return (name + "_v").str(); + } + return name.str(); +} + +} // namespace + +CheddarEmitter::CheddarEmitter(raw_ostream &os, + SelectVariableNames *variableNames, + bool use64Bit, const std::string ¶msJsonPath) + : os(os), + variableNames(variableNames), + use64Bit(use64Bit), + paramsJsonPath(paramsJsonPath) {} + +std::string CheddarEmitter::getName(Value value) { + return sanitizeCppIdentifier(variableNames->getNameForValue(value)); +} + +std::string CheddarEmitter::getContextName(Operation *op) { + if (auto funcOp = op->getParentOfType()) + if (funcOp.getNumArguments() > 0) return getName(funcOp.getArgument(0)); + return ""; +} + +void CheddarEmitter::emitScaleMismatchDebugCheck(StringRef opKind, + StringRef resultName, + Value lhs, Value rhs) { + os << "if (std::getenv(\"HEIR_CHEDDAR_DEBUG_SCALES\")) {\n"; + os.indent(); + os << "double lhs_scale = " << getName(lhs) << ".GetScale();\n"; + os << "double rhs_scale = " << getName(rhs) << ".GetScale();\n"; + os << "if (std::abs(lhs_scale - rhs_scale) >= 1e-12 * lhs_scale) {\n"; + os.indent(); + os << "std::cerr << \"[heir-cheddar] " << opKind << " " << resultName + << " scale mismatch lhs=\" << lhs_scale << \" rhs=\" << rhs_scale" + << " << std::endl;\n"; + os.unindent(); + os << "}\n"; + os.unindent(); + os << "}\n"; +} + +LogicalResult CheddarEmitter::emitVectorDeepCopy(Operation *op, + StringRef destName, + StringRef srcName, + Type elemType, + Type tensorType) { + if (!isContextCopyableMoveOnlyType(elemType)) { + return op->emitOpError( + "cannot deep-copy non-ciphertext CHEDDAR tensor elements"); + } + auto typeStr = convertType(elemType); + if (failed(typeStr)) { + return op->emitOpError("failed to convert CHEDDAR tensor element type"); + } + auto tensorTy = cast(tensorType); + int64_t numElements = tensorTy.getNumElements(); + auto ctxName = getContextName(op); + if (ctxName.empty()) { + return op->emitOpError( + "missing CHEDDAR context argument for deep-copying tensor elements"); + } + os << "std::vector<" << *typeStr << "> " << destName << "(" << numElements + << ");\n"; + os << "for (size_t i = 0; i < " << numElements << "; ++i)\n"; + os.indent(); + os << ctxName << "->Copy(" << destName << "[i], " << srcName << "[i]);\n"; + os.unindent(); + return success(); +} + +FailureOr CheddarEmitter::convertType(Type type, bool asArg) { + return llvm::TypeSwitch>(type) + .Case([](auto) { return std::string("CtxPtr"); }) + .Case([asArg](auto) { + return std::string(asArg ? "const Param&" : "Param"); + }) + .Case([](auto) { return std::string("Enc&"); }) + .Case([](auto) { return std::string("UI&"); }) + .Case( + [asArg](auto) { return std::string(asArg ? "const Ct&" : "Ct"); }) + // Pt args are passed non-const because AddPlainOp may call pt.SetScale() + // to match the ciphertext's drifted scale after LinearTransform eval. + .Case( + [asArg](auto) { return std::string(asArg ? "Pt&" : "Pt"); }) + .Case([asArg](auto) { + return std::string(asArg ? "const Const&" : "Const"); + }) + .Case([](auto) { return std::string("const Evk&"); }) + .Case([](auto) { return std::string("const EvkMapT&"); }) + .Case( + [asArg](RankedTensorType type) -> FailureOr { + auto elemType = type.getElementType(); + if (isa(elemType)) { + return std::string(asArg ? "const std::vector&" + : "std::vector"); + } + if (elemType.isF64() || elemType.isF32()) { + return std::string(asArg ? "const std::vector&" + : "std::vector"); + } + if (elemType.isInteger(32) || elemType.isInteger(64) || + elemType.isIndex()) { + return std::string(asArg ? "const std::vector&" + : "std::vector"); + } + if (isa(elemType)) { + return std::string(asArg ? "const std::vector&" + : "std::vector"); + } + if (isa(elemType)) { + // Non-const to allow AddPlainOp's SetScale on elements. + return std::string(asArg ? "std::vector&" + : "std::vector"); + } + return failure(); + }) + .Case([](auto) { return std::string("double"); }) + .Case([](auto) { return std::string("int64_t"); }) + .Case([](IntegerType type) -> FailureOr { + auto width = type.getWidth(); + if (width == 1) return std::string("bool"); + if (width <= 32) return std::string("int32_t"); + return std::string("int64_t"); + }) + .Default([](Type) { return failure(); }); +} + +void CheddarEmitter::emitPrelude(raw_ostream &os) const { + os << kStdIncludes; + os << kCheddarInclude; + if (needsExtensionIncludes) { + os << kCheddarExtensionInclude; + } + if (needsJsonIncludes) { + os << kJsonInclude; + } + os << "\n"; + if (use64Bit) { + os << kTypeAliasPrelude64; + } else { + os << kTypeAliasPrelude32; + } + os << "\n"; +} + +LogicalResult CheddarEmitter::translate(Operation &op) { + LogicalResult status = + llvm::TypeSwitch(op) + .Case([&](auto op) { return printOperation(op); }) + .Case( + [&](auto op) { return printOperation(op); }) + .Case([&](auto op) { return printOperation(op); }) + // Cheddar setup ops + .Case([&](auto op) { return printOperation(op); }) + // Encode/encrypt/decrypt + .Case( + [&](auto op) { return printOperation(op); }) + // Ct-ct arithmetic + .Case( + [&](auto op) { return printOperation(op); }) + // Ct-pt and ct-const + .Case( + [&](auto op) { return printOperation(op); }) + // Unary / level management + .Case( + [&](auto op) { return printOperation(op); }) + // Fused compound ops + .Case( + [&](auto op) { return printOperation(op); }) + // Extension ops + .Case( + [&](auto op) { return printOperation(op); }) + // SCF control flow + .Case( + [&](auto op) { return printOperation(op); }) + // Tensor ops + .Case( + [&](auto op) { return printOperation(op); }) + // Additional arith ops + .Case([&](auto op) { return printOperation(op); }) + .Default([&](Operation &) { + return op.emitOpError("unable to find printer for op"); + }); + return status; +} + +//===----------------------------------------------------------------------===// +// Module / Function +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(ModuleOp moduleOp) { + // Read module-level scheme params set by ConfigureCryptoContext + if (auto attr = + moduleOp->getAttrOfType("cheddar.logDefaultScale")) + logDefaultScale = attr.getInt(); + if (auto attr = moduleOp->getAttrOfType("cheddar.logN")) + logN = attr.getInt(); + if (auto attr = moduleOp->getAttrOfType("cheddar.Q")) + qPrimes.assign(attr.asArrayRef().begin(), attr.asArrayRef().end()); + if (auto attr = moduleOp->getAttrOfType("cheddar.P")) + pPrimes.assign(attr.asArrayRef().begin(), attr.asArrayRef().end()); + + // Collect rotation distances used across the module + moduleOp->walk([&](Operation *op) { + if (auto hrotOp = dyn_cast(op)) { + if (auto staticShift = hrotOp.getStaticShift()) { + rotationDistances.insert(staticShift->getInt()); + } else if (auto dynShift = hrotOp.getDynamicShift()) { + if (auto constOp = dynShift.getDefiningOp()) + if (auto intAttr = dyn_cast(constOp.getValue())) + rotationDistances.insert(intAttr.getInt()); + } + } else if (auto hrotAddOp = dyn_cast(op)) { + rotationDistances.insert(hrotAddOp.getDistanceAttr().getInt()); + } else if (auto linTransOp = dyn_cast(op)) { + // Compute BSGS rotation indices accounting for pre_rotation on + // wrap-around diagonal matrices. + auto diagType = + cast(linTransOp.getDiagonals().getType()); + int64_t ltSlots = diagType.getShape()[1]; + auto diagIdx = linTransOp.getDiagonalIndicesAttr().asArrayRef(); + + // Determine pre_rotation (same logic as in printOperation). + int64_t preRot = 0; + for (auto idx : diagIdx) { + if (idx > ltSlots / 2) { + preRot = idx; + break; + } + } + for (auto idx : diagIdx) { + if (idx > ltSlots / 2 && idx < preRot) preRot = idx; + } + + // Adjusted indices after pre_rotation. + SmallVector adjusted; + for (auto idx : diagIdx) { + adjusted.push_back( + static_cast((idx - preRot + ltSlots) % ltSlots)); + } + + int64_t logBSGS = linTransOp.getLogBabyStepGiantStepRatio().getInt(); + auto rots = lintransRotationIndices(adjusted, ltSlots, logBSGS); + for (int64_t r : rots) rotationDistances.insert(r); + } + }); + + for (Operation &op : moduleOp) { + if (failed(translate(op))) return failure(); + } + + // Emit __configure() if we have the required module-level params + if (logN > 0 && !qPrimes.empty()) { + os << "std::tuple __configure() {\n"; + os.indent(); + + // Main primes (Q) + os << "static std::vector main_primes = {"; + for (size_t i = 0; i < qPrimes.size(); ++i) { + if (i > 0) os << ", "; + os << static_cast(qPrimes[i]) << "ULL"; + } + os << "};\n"; + + // Auxiliary primes (P) + os << "static std::vector aux_primes = {"; + for (size_t i = 0; i < pPrimes.size(); ++i) { + if (i > 0) os << ", "; + os << static_cast(pPrimes[i]) << "ULL"; + } + os << "};\n"; + + // Level config + os << "static std::vector> level_config = []() {\n"; + os.indent(); + os << "std::vector> lc;\n"; + os << "for (int i = 1; i <= static_cast(main_primes.size()); ++i)\n"; + os.indent(); + os << "lc.push_back({i, 0});\n"; + os.unindent(); + os << "return lc;\n"; + os.unindent(); + os << "}();\n"; + + // Parameter must outlive Context (Context stores a reference to it) + os << "static Param param(" << logN << ", static_cast(1ULL << " + << logDefaultScale << "), " + << "static_cast(main_primes.size()) - 1, " + << "level_config, main_primes, aux_primes);\n"; + os << "auto ctx = Context::Create(param);\n"; + os << "UI ui(ctx);\n"; + + // Rotation keys. + if (!rotationDistances.empty()) { + for (int64_t dist : rotationDistances) { + os << "ui.PrepareRotationKey(" << dist + << ", static_cast(main_primes.size()) - 1);\n"; + } + } + + os << "return {ctx, std::move(ui)};\n"; + os.unindent(); + os << "}\n\n"; + } + + return success(); +} + +LogicalResult CheddarEmitter::printOperation(func::FuncOp funcOp) { + // Emit function signature + auto funcType = funcOp.getFunctionType(); + auto resultTypes = funcType.getResults(); + + // Return type + if (resultTypes.empty()) { + os << "void "; + } else if (resultTypes.size() == 1) { + auto typeStr = convertType(resultTypes[0]); + if (failed(typeStr)) + return funcOp.emitOpError("failed to convert return type"); + os << *typeStr << " "; + } else { + // Multiple returns: use std::tuple + os << "std::tuple<"; + for (unsigned i = 0; i < resultTypes.size(); ++i) { + if (i > 0) os << ", "; + auto typeStr = convertType(resultTypes[i]); + if (failed(typeStr)) + return funcOp.emitOpError("failed to convert return type"); + os << *typeStr; + } + os << "> "; + } + + // Function name and arguments + os << funcOp.getName() << "("; + auto argTypes = funcType.getInputs(); + auto &entryBlock = funcOp.getBody().front(); + for (unsigned i = 0; i < argTypes.size(); ++i) { + if (i > 0) os << ", "; + auto typeStr = convertType(argTypes[i], /*asArg=*/true); + if (failed(typeStr)) + return funcOp.emitOpError("failed to convert argument type"); + os << *typeStr << " " << getName(entryBlock.getArgument(i)); + } + os << ") {\n"; + os.indent(); + + for (Block &block : funcOp.getBody()) { + for (Operation &op : block.getOperations()) { + if (failed(translate(op))) return failure(); + } + } + + os.unindent(); + os << "}\n\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(func::ReturnOp op) { + if (op.getNumOperands() == 0) { + os << "return;\n"; + } else if (op.getNumOperands() == 1) { + os << "return " << getName(op.getOperand(0)) << ";\n"; + } else { + os << "return std::make_tuple("; + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (i > 0) os << ", "; + os << "std::move(" << getName(op.getOperand(i)) << ")"; + } + os << ");\n"; + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(func::CallOp op) { + if (op.getNumResults() == 1) { + auto typeStr = convertType(op.getResult(0).getType()); + if (failed(typeStr)) return op.emitOpError("failed to convert result type"); + os << *typeStr << " " << getName(op.getResult(0)) << " = "; + } else if (op.getNumResults() > 1) { + os << "auto ["; + for (unsigned i = 0; i < op.getNumResults(); ++i) { + if (i > 0) os << ", "; + os << getName(op.getResult(i)); + } + os << "] = "; + } + + os << op.getCallee() << "("; + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (i > 0) os << ", "; + os << getName(op.getOperand(i)); + } + os << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Arith ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(arith::ConstantOp op) { + auto result = op.getResult(); + auto attr = op.getValue(); + auto tensorType = dyn_cast(result.getType()); + + if (auto denseAttr = dyn_cast(attr)) { + auto name = getName(result); + auto typeStr = convertType(result.getType()); + if (failed(typeStr)) + return op.emitOpError("failed to convert dense constant type"); + if (tensorType && denseAttr.isSplat()) { + SmallString<16> str; + denseAttr.getSplatValue().toString(str, /*FormatPrecision=*/15); + os << *typeStr << " " << name << "(" << tensorType.getNumElements() + << ", " << str << ");\n"; + return success(); + } + os << *typeStr << " " << name << " = {"; + bool first = true; + for (auto val : denseAttr.getValues()) { + if (!first) os << ", "; + first = false; + SmallString<16> str; + val.toString(str, /*FormatPrecision=*/15); + os << str; + } + os << "};\n"; + return success(); + } + + if (auto floatAttr = dyn_cast(attr)) { + SmallString<16> str; + floatAttr.getValue().toString(str, /*FormatPrecision=*/15); + os << "double " << getName(result) << " = " << str << ";\n"; + return success(); + } + + if (auto denseIntAttr = dyn_cast(attr)) { + auto name = getName(result); + auto typeStr = convertType(result.getType()); + if (failed(typeStr)) + return op.emitOpError("failed to convert dense int constant type"); + if (tensorType && denseIntAttr.isSplat()) { + os << *typeStr << " " << name << "(" << tensorType.getNumElements() + << ", " << denseIntAttr.getSplatValue().getSExtValue() + << ");\n"; + return success(); + } + os << *typeStr << " " << name << " = {"; + bool first = true; + for (auto val : denseIntAttr.getValues()) { + if (!first) os << ", "; + first = false; + os << val.getSExtValue(); + } + os << "};\n"; + return success(); + } + + if (auto intAttr = dyn_cast(attr)) { + auto type = result.getType(); + auto typeStr = convertType(type); + if (failed(typeStr)) + return op.emitOpError("failed to convert constant type"); + os << *typeStr << " " << getName(result) << " = " + << intAttr.getValue().getSExtValue() << ";\n"; + return success(); + } + + return op.emitOpError("unsupported constant type for CHEDDAR emitter"); +} + +//===----------------------------------------------------------------------===// +// Setup ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(CreateContextOp op) { + os << "auto " << getName(op.getCtx()) << " = Context::Create(" + << getName(op.getParams()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(CreateUserInterfaceOp op) { + os << "UI " << getName(op.getUi()) << "(" << getName(op.getCtx()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(GetEncoderOp op) { + os << "auto& " << getName(op.getEncoder()) << " = " << getName(op.getCtx()) + << "->encoder_;\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(GetEvkMapOp op) { + os << "const auto& " << getName(op.getEvkMap()) << " = " + << getName(op.getUi()) << ".GetEvkMap();\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(GetMultKeyOp op) { + os << "const auto& " << getName(op.getKey()) << " = " << getName(op.getUi()) + << ".GetMultiplicationKey();\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(GetRotKeyOp op) { + auto dist = op.getDistanceAttr().getInt(); + if (dist == cheddar::kDynamicRotationKeyDistanceSentinel) { + // Dynamic rotation keys are looked up at the HRot call site, since the + // actual rotation distance is only available there at runtime. + return success(); + } + os << "const auto& " << getName(op.getKey()) << " = " << getName(op.getUi()) + << ".GetRotationKey(" << dist << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(GetConjKeyOp op) { + os << "const auto& " << getName(op.getKey()) << " = " << getName(op.getUi()) + << ".GetConjugationKey();\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(PrepareRotKeyOp op) { + auto dist = op.getDistanceAttr().getInt(); + auto maxLevel = op.getMaxLevelAttr().getInt(); + os << getName(op.getUi()) << ".PrepareRotationKey(" << dist << ", " + << maxLevel << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Encode / Encrypt / Decrypt +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(EncodeOp op) { + auto level = op.getLevelAttr().getInt(); + auto logScale = op.getScaleAttr().getInt(); + auto name = getName(op.getPlaintext()); + auto msgName = getName(op.getMessage()); + + if (logDefaultScale <= 0) { + return op.emitOpError( + "requires module attribute `cheddar.logDefaultScale`"); + } + + // Encode expects Complex; convert real inputs. + auto msgType = op.getMessage().getType(); + bool needsComplexConversion = false; + if (auto tensorType = dyn_cast(msgType)) { + auto elemType = tensorType.getElementType(); + needsComplexConversion = elemType.isF32() || elemType.isF64(); + } + + // Use Parameter::GetScale(level) for the runtime scale. Square for + // doubled scale (e.g. post-mult before rescale). + std::string ctxName; + if (auto funcOp = op->getParentOfType()) { + if (funcOp.getNumArguments() > 0) ctxName = getName(funcOp.getArgument(0)); + } + + std::string scaleExpr; + if (!ctxName.empty()) { + std::string baseScale = + ctxName + "->param_.GetScale(" + std::to_string(level) + ")"; + if (logScale < logDefaultScale || logScale % logDefaultScale != 0) { + return op.emitOpError() + << "requires CHEDDAR scale to be a positive integer multiple of " + "the module logDefaultScale"; + } + if (logScale > logDefaultScale) { + int multiplier = logScale / logDefaultScale; + scaleExpr = baseScale; + for (int i = 1; i < multiplier; ++i) scaleExpr += " * " + baseScale; + } else { + scaleExpr = baseScale; + } + } else { + // Fallback if context not found + scaleExpr = "pow(2.0, " + std::to_string(logScale) + ")"; + } + + os << "Pt " << name << ";\n"; + if (needsComplexConversion) { + std::string complexMsgName = name + "_complex"; + os << "std::vector " << complexMsgName << "(" << msgName + << ".begin(), " << msgName << ".end());\n"; + os << getName(op.getEncoder()) << ".Encode(" << name << ", " << level + << ", " << scaleExpr << ", " << complexMsgName << ");\n"; + } else { + os << getName(op.getEncoder()) << ".Encode(" << name << ", " << level + << ", " << scaleExpr << ", " << msgName << ");\n"; + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(EncodeConstantOp op) { + auto level = op.getLevelAttr().getInt(); + auto logScale = op.getScaleAttr().getInt(); + auto name = getName(op.getConstant()); + + if (logDefaultScale <= 0) { + return op.emitOpError( + "requires module attribute `cheddar.logDefaultScale`"); + } + + std::string ctxName; + if (auto funcOp = op->getParentOfType()) { + if (funcOp.getNumArguments() > 0) ctxName = getName(funcOp.getArgument(0)); + } + + std::string scaleExpr; + if (!ctxName.empty()) { + std::string baseScale = + ctxName + "->param_.GetScale(" + std::to_string(level) + ")"; + if (logScale < logDefaultScale || logScale % logDefaultScale != 0) { + return op.emitOpError() + << "requires CHEDDAR scale to be a positive integer multiple of " + "the module logDefaultScale"; + } + if (logScale > logDefaultScale) { + int multiplier = logScale / logDefaultScale; + scaleExpr = baseScale; + for (int i = 1; i < multiplier; ++i) scaleExpr += " * " + baseScale; + } else { + scaleExpr = baseScale; + } + } else { + scaleExpr = "pow(2.0, " + std::to_string(logScale) + ")"; + } + + os << "Const " << name << ";\n"; + os << getName(op.getEncoder()) << ".EncodeConstant(" << name << ", " << level + << ", " << scaleExpr << ", " << getName(op.getValue()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(DecodeOp op) { + auto msgType = op.getMessage().getType(); + bool needsRealConversion = false; + if (auto tensorType = dyn_cast(msgType)) { + auto elemType = tensorType.getElementType(); + needsRealConversion = elemType.isF32() || elemType.isF64(); + } + + if (needsRealConversion) { + // CHEDDAR Decode returns Complex; convert to double for float tensors + std::string complexName = getName(op.getMessage()) + "_complex"; + os << "std::vector " << complexName << ";\n"; + os << getName(op.getEncoder()) << ".Decode(" << complexName << ", " + << getName(op.getPlaintext()) << ");\n"; + os << "std::vector " << getName(op.getMessage()) << "(" + << complexName << ".size());\n"; + os << "for (size_t i = 0; i < " << complexName << ".size(); ++i) " + << getName(op.getMessage()) << "[i] = " << complexName + << "[i].real();\n"; + } else { + auto name = getName(op.getMessage()); + os << "std::vector " << name << ";\n"; + os << getName(op.getEncoder()) << ".Decode(" << name << ", " + << getName(op.getPlaintext()) << ");\n"; + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(EncryptOp op) { + auto name = getName(op.getCiphertext()); + os << "Ct " << name << ";\n"; + os << getName(op.getUi()) << ".Encrypt(" << name << ", " + << getName(op.getPlaintext()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(DecryptOp op) { + auto name = getName(op.getPlaintext()); + os << "Pt " << name << ";\n"; + os << getName(op.getUi()) << ".Decrypt(" << name << ", " + << getName(op.getCiphertext()) << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Binary ct-ct ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(AddOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + emitScaleMismatchDebugCheck("add", name, op.getLhs(), op.getRhs()); + os << getName(op.getCtx()) << "->Add(" << name << ", " << getName(op.getLhs()) + << ", " << getName(op.getRhs()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(SubOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + emitScaleMismatchDebugCheck("sub", name, op.getLhs(), op.getRhs()); + os << getName(op.getCtx()) << "->Sub(" << name << ", " << getName(op.getLhs()) + << ", " << getName(op.getRhs()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(MultOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Mult(" << name << ", " + << getName(op.getLhs()) << ", " << getName(op.getRhs()) << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Ct-pt / ct-const ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(AddPlainOp op) { + auto name = getName(op.getOutput()); + auto ctName = getName(op.getCiphertext()); + auto ptName = getName(op.getPlaintext()); + os << "Ct " << name << ";\n"; + // Fix up plaintext scale to exactly match ciphertext scale. This is needed + // because CHEDDAR's LinearTransform produces scales via floating-point + // arithmetic that may differ slightly from GetScale(level). + os << ptName << ".SetScale(" << ctName << ".GetScale());\n"; + os << getName(op.getCtx()) << "->Add(" << name << ", " << ctName << ", " + << ptName << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(SubPlainOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + emitScaleMismatchDebugCheck("sub_plain", name, op.getCiphertext(), + op.getPlaintext()); + os << getName(op.getCtx()) << "->Sub(" << name << ", " + << getName(op.getCiphertext()) << ", " << getName(op.getPlaintext()) + << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(MultPlainOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Mult(" << name << ", " + << getName(op.getCiphertext()) << ", " << getName(op.getPlaintext()) + << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(AddConstOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Add(" << name << ", " + << getName(op.getCiphertext()) << ", " << getName(op.getConstant()) + << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(MultConstOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Mult(" << name << ", " + << getName(op.getCiphertext()) << ", " << getName(op.getConstant()) + << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Unary ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(NegOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Neg(" << name << ", " + << getName(op.getInput()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(RescaleOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Rescale(" << name << ", " + << getName(op.getInput()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(LevelDownOp op) { + auto name = getName(op.getOutput()); + auto level = op.getTargetLevelAttr().getInt(); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->LevelDown(" << name << ", " + << getName(op.getInput()) << ", " << level << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(RelinearizeOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Relinearize(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getMultKey()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(RelinearizeRescaleOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->RelinearizeRescale(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getMultKey()) << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Fused compound ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(HMultOp op) { + auto name = getName(op.getOutput()); + bool rescale = op.getRescale(); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->HMult(" << name << ", " + << getName(op.getLhs()) << ", " << getName(op.getRhs()) << ", " + << getName(op.getMultKey()) << ", " << (rescale ? "true" : "false") + << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(HRotOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + if (auto staticShift = op.getStaticShift()) { + // Static shift: use the pre-fetched rotation key + os << getName(op.getCtx()) << "->HRot(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getRotKey()) << ", " + << staticShift->getInt() << ");\n"; + } else { + // Dynamic shift: look up the rotation key at runtime. + auto shiftName = getName(op.getDynamicShift()); + // Find the UserInterface from the GetRotKeyOp's operand + auto getRotKeyOp = op.getRotKey().getDefiningOp(); + if (getRotKeyOp) { + if (getRotKeyOp.getDistanceAttr().getInt() != + cheddar::kDynamicRotationKeyDistanceSentinel) { + return op.emitOpError( + "dynamic rotation requires a sentinel GetRotKeyOp placeholder"); + } + os << getName(op.getCtx()) << "->HRot(" << name << ", " + << getName(op.getInput()) << ", " << getName(getRotKeyOp.getUi()) + << ".GetRotationKey(" << shiftName << "), " << shiftName << ");\n"; + } else { + return op.emitOpError( + "dynamic rotation requires GetRotKeyOp to trace back to " + "UserInterface"); + } + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(HRotAddOp op) { + auto name = getName(op.getOutput()); + auto dist = op.getDistanceAttr().getInt(); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->HRotAdd(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getAddend()) << ", " + << getName(op.getRotKey()) << ", " << dist << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(HConjOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->HConj(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getConjKey()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(HConjAddOp op) { + auto name = getName(op.getOutput()); + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->HConjAdd(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getAddend()) << ", " + << getName(op.getConjKey()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(MadUnsafeOp op) { + auto name = getName(op.getOutput()); + // MadUnsafe is in-place: res += a * const + // We need to copy the accumulator first + os << "Ct " << name << ";\n"; + os << getName(op.getCtx()) << "->Copy(" << name << ", " + << getName(op.getAccumulator()) << ");\n"; + os << getName(op.getCtx()) << "->MadUnsafe(" << name << ", " + << getName(op.getInput()) << ", " << getName(op.getConstant()) << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Extension ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(BootOp op) { + needsExtensionIncludes = true; + return op.emitOpError( + "CHEDDAR bootstrapping emission not yet implemented " + "(requires PrepareEvalMod and PrepareEvalSpecialFFT)"); +} + +LogicalResult CheddarEmitter::printOperation(LinearTransformOp op) { + needsExtensionIncludes = true; + auto name = getName(op.getOutput()); + auto diagIndices = op.getDiagonalIndicesAttr().asArrayRef(); + auto level = op.getLevelAttr().getInt(); + auto logBSGS = op.getLogBabyStepGiantStepRatioAttr().getInt(); + auto diagonalsName = getName(op.getDiagonals()); + auto ctxName = getName(op.getCtx()); + auto evkMapName = getName(op.getEvkMap()); + auto inputName = getName(op.getInput()); + + // Get diagonals tensor shape to determine slots + auto diagType = cast(op.getDiagonals().getType()); + int64_t numDiags = diagType.getShape()[0]; + int64_t slots = diagType.getShape()[1]; + + // Compute baby-step / giant-step from logBSGS ratio. + // If diagonal indices wrap around the slot range (some > slots/2), compute + // a pre_rotation to shift them into a contiguous range for BSGS. + int64_t preRotation = 0; + bool hasWrapAround = false; + int64_t minNeg = slots; // smallest negative-rotation index + for (auto idx : diagIndices) { + if (idx > slots / 2) { + hasWrapAround = true; + minNeg = std::min(minNeg, static_cast(idx)); + } + } + if (hasWrapAround) preRotation = minNeg; + + // Use the same BSGS computation as the rotation-index analysis so that + // the prepared rotation keys match what CHEDDAR actually requests. + // Compute on the pre-rotation-adjusted diagonal indices. + SmallVector adjustedDiags; + for (auto idx : diagIndices) { + adjustedDiags.push_back( + static_cast((idx - preRotation + slots) % slots)); + } + int64_t bs = + (logBSGS > 0) ? findBestBSGSRatio(adjustedDiags, slots, logBSGS) : 1; + int64_t gs = (bs > 1) ? (numDiags + bs - 1) / bs : numDiags; + + // Unique name for this linear transform + std::string ltName = name + "_lt"; + std::string matName = name + "_mat"; + + // Build the StripedMatrix from the diagonals tensor. CHEDDAR's + // LinearTransform requires a square StripedMatrix whose logical dimensions + // match the slot count; only the non-zero diagonals are actually stored. + os << "StripedMatrix " << matName << "(" << slots << ", " << slots << ");\n"; + os << "{\n"; + os.indent(); + os << "auto* data = " << diagonalsName << ".data();\n"; + for (int64_t i = 0; i < numDiags; ++i) { + os << matName << "[" << diagIndices[i] << "] = std::vector(data + " + << i * slots << ", data + " << (i + 1) * slots << ");\n"; + } + os.unindent(); + os << "}\n"; + + // Construct LinearTransform. CHEDDAR's LT internally rescales (output at + // pt_level - 1), so the diagonal plaintext scale should match the OUTPUT + // level's scale, not the input level's scale. + os << "LinearTransform " << ltName << "(" << ctxName << ", " << matName + << ", " << level << ", " << ctxName << "->param_.GetScale(" << level + << " - 1), " << bs << ", " << gs << ", " << preRotation << ");\n"; + + // Evaluate + os << "Ct " << name << ";\n"; + os << ltName << ".Evaluate(" << ctxName << ", " << name << ", " << inputName + << ", " << evkMapName << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(EvalPolyOp op) { + needsExtensionIncludes = true; + auto name = getName(op.getOutput()); + auto ctxName = getName(op.getCtx()); + auto evkMapName = getName(op.getEvkMap()); + auto inputName = getName(op.getInput()); + auto coeffsAttr = op.getCoefficientsAttr(); + + // Emit coefficient vector + std::string coeffsName = name + "_coeffs"; + os << "std::vector " << coeffsName << " = {"; + for (unsigned i = 0; i < coeffsAttr.size(); ++i) { + if (i > 0) os << ", "; + auto floatAttr = dyn_cast(coeffsAttr[i]); + if (!floatAttr) { + return op.emitOpError() + << "requires all polynomial coefficients to be FloatAttr"; + } + SmallString<16> str; + floatAttr.getValue().toString(str, /*FormatPrecision=*/15); + os << str; + } + os << "};\n"; + + // Get the mult key from the UI via the evkMap's defining GetEvkMapOp + std::string multKeyExpr; + if (auto getEvkMapOp = op.getEvkMap().getDefiningOp()) { + multKeyExpr = getName(getEvkMapOp.getUi()) + ".GetMultiplicationKey()"; + } else { + return op.emitOpError( + "could not trace evkMap back to UserInterface for mult key"); + } + + auto level = op.getLevelAttr().getInt(); + std::string scaleExpr = + ctxName + "->param_.GetScale(" + std::to_string(level) + ")"; + + std::string epName = name + "_ep"; + os << "EvalPoly " << epName << "(" << coeffsName << ", " << level + << ", " << scaleExpr << ", " << scaleExpr << ", false);\n"; + os << epName << ".Compile(" << ctxName << ");\n"; + os << "Ct " << name << ";\n"; + os << epName << ".Evaluate(" << ctxName << ", " << name << ", " << inputName + << ", " << multKeyExpr << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// SCF control flow ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(scf::ForOp op) { + // Initialize iter args + for (unsigned i = 0; i < op.getNumRegionIterArgs(); ++i) { + Value result = op.getResults()[i]; + Value init = op.getInitArgs()[i]; + Value iterArg = op.getRegionIterArgs()[i]; + if (variableNames->contains(result) && variableNames->contains(init) && + getName(result) == getName(init)) + continue; + auto typeStr = convertType(iterArg.getType()); + if (failed(typeStr)) + return op.emitOpError("failed to convert iter arg type"); + os << *typeStr << " " << getName(iterArg) << " = std::move(" + << getName(init) << ");\n"; + } + + // for (int64_t iv = lb; iv < ub; iv += step) + auto getVal = [&](Value v) -> std::string { + if (auto constOp = v.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) + return std::to_string(intAttr.getInt()); + } + return getName(v); + }; + os << "for (int64_t " << getName(op.getInductionVar()) << " = " + << getVal(op.getLowerBound()) << "; " << getName(op.getInductionVar()) + << " < " << getVal(op.getUpperBound()) << "; " + << getName(op.getInductionVar()) << " += " << getVal(op.getStep()) + << ") {\n"; + os.indent(); + for (Operation &bodyOp : *op.getBody()) { + if (failed(translate(bodyOp))) return failure(); + } + os.unindent(); + os << "}\n"; + + // Forward iter args to results + for (unsigned i = 0; i < op.getNumResults(); ++i) { + Value opResult = op.getResult(i); + Value iterArg = op.getRegionIterArg(i); + if (getName(opResult) != getName(iterArg)) { + auto typeStr = convertType(opResult.getType()); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(opResult) << " = std::move(" + << getName(iterArg) << ");\n"; + } + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(scf::IfOp op) { + // Declare result variables + for (unsigned i = 0; i < op.getNumResults(); ++i) { + auto typeStr = convertType(op.getResults()[i].getType()); + if (failed(typeStr)) + return op.emitOpError("failed to convert if result type"); + os << *typeStr << " " << getName(op.getResults()[i]) << ";\n"; + } + + os << "if (" << getName(op.getCondition()) << ") {\n"; + os.indent(); + for (Operation &thenOp : *op.thenBlock()) { + if (failed(translate(thenOp))) return failure(); + } + os.unindent(); + os << "} else {\n"; + os.indent(); + for (Operation &elseOp : *op.elseBlock()) { + if (failed(translate(elseOp))) return failure(); + } + os.unindent(); + os << "}\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(scf::YieldOp op) { + ValueRange destValues = + llvm::TypeSwitch(op->getParentOp()) + .Case( + [&](auto forOp) { return forOp.getRegionIterArgs(); }) + .Case([&](auto ifOp) { return ifOp.getResults(); }) + .Default([&](auto) { return ValueRange{}; }); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + os << getName(destValues[i]) << " = std::move(" << getName(op.getOperand(i)) + << ");\n"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Tensor ops (tensors are flattened to std::vector) +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(tensor::EmptyOp op) { + auto resultType = op.getResult().getType(); + auto elemType = resultType.getElementType(); + auto typeStr = convertType(elemType); + if (failed(typeStr)) return op.emitOpError("failed to convert element type"); + bool isMoveOnly = isa(elemType); + if (isMoveOnly) { + // Move-only: reserve without constructing (will be filled via insert) + os << "std::vector<" << *typeStr << "> " << getName(op.getResult()) + << ";\n"; + os << getName(op.getResult()) << ".resize(" << resultType.getNumElements() + << ");\n"; + } else { + os << "std::vector<" << *typeStr << "> " << getName(op.getResult()) << "(" + << resultType.getNumElements() << ");\n"; + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::ExtractOp op) { + auto resultType = op.getResult().getType(); + bool isMoveOnly = + isa(resultType); + if (isMoveOnly) { + // Reference to avoid copying move-only types + os << "auto& " << getName(op.getResult()) << " = " + << getName(op.getTensor()) << "["; + } else { + auto typeStr = convertType(resultType); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(op.getResult()) << " = " + << getName(op.getTensor()) << "["; + } + os << flattenIndexExpression(op.getTensor().getType(), op.getIndices(), + [&](Value value) { return getName(value); }); + os << "];\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::InsertOp op) { + bool inPlace = variableNames->contains(op.getResult()) && + variableNames->contains(op.getDest()) && + getName(op.getResult()) == getName(op.getDest()); + auto resultType = cast(op.getResult().getType()); + auto elemType = resultType.getElementType(); + bool isMoveOnly = isa(elemType); + bool destIsFreshEmpty = + isa_and_nonnull(op.getDest().getDefiningOp()); + bool canMoveDest = + canMoveTensorDestIntoResult(op.getDest(), op.getOperation()); + + std::string resultName; + if (inPlace) { + resultName = getName(op.getResult()); + } else { + // Value semantics: dest is immutable, result is a modified copy. + resultName = getName(op.getResult()); + if (isMoveOnly) { + if (destIsFreshEmpty) { + auto typeResult = convertType(elemType); + if (failed(typeResult)) { + return op.emitOpError("failed to convert element type"); + } + auto resultType = cast(op.getResult().getType()); + os << "std::vector<" << *typeResult << "> " << resultName << ";\n"; + os << resultName << ".resize(" << resultType.getNumElements() << ");\n"; + } else if (canMoveDest) { + os << "auto " << resultName << " = std::move(" << getName(op.getDest()) + << ");\n"; + } else if (isContextCopyableMoveOnlyType(elemType)) { + if (failed(emitVectorDeepCopy(op, resultName, getName(op.getDest()), + elemType, op.getDest().getType()))) { + return failure(); + } + } else { + return op.emitOpError( + "cannot duplicate move-only plaintext/constant tensor destination"); + } + } else { + if (destIsFreshEmpty) { + auto typeResult = convertType(elemType); + if (failed(typeResult)) { + return op.emitOpError("failed to convert element type"); + } + os << "std::vector<" << *typeResult << "> " << resultName << "(" + << resultType.getNumElements() << ");\n"; + } else if (canMoveDest) { + os << "auto " << resultName << " = std::move(" << getName(op.getDest()) + << ");\n"; + } else { + os << "auto " << resultName << " = " << getName(op.getDest()) << ";\n"; + } + } + } + + auto scalarName = getName(op.getScalar()); + std::string idxExpr = + flattenIndexExpression(op.getResult().getType(), op.getIndices(), + [&](Value value) { return getName(value); }); + if (isMoveOnly) { + if (isContextCopyableMoveOnlyType(elemType)) { + auto copyName = scalarName + "_c" + std::to_string(tempVarCounter++); + auto typeResult = convertType(elemType); + if (failed(typeResult)) { + return op.emitOpError("failed to convert CHEDDAR tensor element type"); + } + auto ctxName = getContextName(op); + if (ctxName.empty()) { + return op.emitOpError( + "missing CHEDDAR context argument for tensor.insert copy"); + } + os << *typeResult << " " << copyName << ";\n"; + os << ctxName << "->Copy(" << copyName << ", " << scalarName << ");\n"; + os << resultName << "[" << idxExpr << "] = std::move(" << copyName + << ");\n"; + } else if (canMoveValueIntoConsumer(op.getScalar())) { + os << resultName << "[" << idxExpr << "] = std::move(" << scalarName + << ");\n"; + } else { + return op.emitOpError( + "cannot duplicate move-only plaintext/constant scalar for " + "tensor.insert"); + } + } else { + os << resultName << "[" << idxExpr << "] = " << scalarName << ";\n"; + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::FromElementsOp op) { + auto elemType = getElementTypeOrSelf(op.getResult().getType()); + auto typeStr = convertType(elemType); + if (failed(typeStr)) return failure(); + bool isMoveOnly = isa(elemType); + if (isMoveOnly) { + os << "std::vector<" << *typeStr << "> " << getName(op.getResult()) + << ";\n"; + os << getName(op.getResult()) << ".reserve(" << op.getNumOperands() + << ");\n"; + auto ctxName = getContextName(op); + if (ctxName.empty() && isContextCopyableMoveOnlyType(elemType)) { + return op.emitOpError( + "missing CHEDDAR context argument for tensor.from_elements copy"); + } + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto operandName = getName(op.getOperand(i)); + if (isContextCopyableMoveOnlyType(elemType)) { + auto copyName = operandName + "_c" + std::to_string(tempVarCounter++); + os << *typeStr << " " << copyName << ";\n"; + os << ctxName << "->Copy(" << copyName << ", " << operandName << ");\n"; + os << getName(op.getResult()) << ".emplace_back(std::move(" << copyName + << "));\n"; + } else if (canMoveValueIntoConsumer(op.getOperand(i))) { + os << getName(op.getResult()) << ".emplace_back(std::move(" + << operandName << "));\n"; + } else { + return op.emitOpError( + "cannot duplicate move-only plaintext/constant for " + "tensor.from_elements"); + } + } + } else { + os << "std::vector<" << *typeStr << "> " << getName(op.getResult()) + << " = {"; + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (i > 0) os << ", "; + os << getName(op.getOperand(i)); + } + os << "};\n"; + } + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::SplatOp op) { + auto resultType = op.getResult().getType(); + auto typeStr = convertType(resultType.getElementType()); + if (failed(typeStr)) return op.emitOpError("failed to convert element type"); + os << "std::vector<" << *typeStr << "> " << getName(op.getResult()) << "(" + << resultType.getNumElements() << ", " << getName(op.getInput()) << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::ExpandShapeOp op) { + // Tensors are flat — expand_shape is a no-op alias. + SliceVerificationResult res = + isRankReducedType(op.getResultType(), op.getSrcType()); + if (res != SliceVerificationResult::Success) { + return op.emitError() + << "Only rank-reduced types are supported for ExpandShapeOp"; + } + variableNames->mapValueNameToValue(op.getResult(), op.getSrc()); + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::ExtractSliceOp op) { + auto resultType = op.getResultType(); + auto typeStr = convertType(resultType.getElementType()); + if (failed(typeStr)) return failure(); + + auto getOffsetStr = [&](OpFoldResult ofr) -> std::string { + if (auto attr = dyn_cast(ofr)) + return std::to_string(cast(attr).getInt()); + return getName(cast(ofr)); + }; + + // Simple contiguous case: just take a sub-span + auto offsets = op.getMixedOffsets(); + auto sizes = op.getMixedSizes(); + int64_t numElements = resultType.getNumElements(); + + auto srcType = op.getSourceType(); + auto emitFlatOffset = [&]() { + for (unsigned i = 0; i < offsets.size(); ++i) { + if (i > 0) os << " + "; + int64_t stride = 1; + for (unsigned j = i + 1; j < srcType.getRank(); ++j) + stride *= srcType.getDimSize(j); + os << getOffsetStr(offsets[i]); + if (stride != 1) os << " * " << stride; + } + }; + os << "std::vector<" << *typeStr << "> " << getName(op.getResult()) << "(" + << getName(op.getSource()) << ".begin() + "; + emitFlatOffset(); + os << ", " << getName(op.getSource()) << ".begin() + "; + emitFlatOffset(); + os << " + " << numElements << ");\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(tensor::InsertSliceOp op) { + bool inPlace = variableNames->contains(op.getResult()) && + variableNames->contains(op.getDest()) && + getName(op.getResult()) == getName(op.getDest()); + auto elemType = op.getType().getElementType(); + bool isMoveOnly = isa(elemType); + bool destIsFreshEmpty = + isa_and_nonnull(op.getDest().getDefiningOp()); + bool canMoveDest = + canMoveTensorDestIntoResult(op.getDest(), op.getOperation()); + + std::string resultName; + if (inPlace) { + resultName = getName(op.getResult()); + } else { + resultName = getName(op.getResult()); + if (isMoveOnly) { + if (destIsFreshEmpty) { + auto typeResult = convertType(elemType); + if (failed(typeResult)) + return op.emitOpError("failed to convert element type"); + os << "std::vector<" << *typeResult << "> " << resultName << ";\n"; + os << resultName << ".resize(" << op.getType().getNumElements() + << ");\n"; + } else if (canMoveDest) { + os << "auto " << resultName << " = std::move(" << getName(op.getDest()) + << ");\n"; + } else if (isContextCopyableMoveOnlyType(elemType)) { + if (failed(emitVectorDeepCopy(op, resultName, getName(op.getDest()), + elemType, op.getDest().getType()))) { + return failure(); + } + } else { + return op.emitOpError( + "cannot duplicate move-only plaintext/constant tensor destination"); + } + } else { + if (destIsFreshEmpty) { + auto typeResult = convertType(elemType); + if (failed(typeResult)) + return op.emitOpError("failed to convert element type"); + os << "std::vector<" << *typeResult << "> " << resultName << "(" + << op.getType().getNumElements() << ");\n"; + } else if (canMoveDest) { + os << "auto " << resultName << " = std::move(" << getName(op.getDest()) + << ");\n"; + } else { + os << "auto " << resultName << " = " << getName(op.getDest()) << ";\n"; + } + } + } + // Copy source into dest at offset + auto offsets = op.getMixedOffsets(); + auto getOffsetStr = [&](OpFoldResult ofr) -> std::string { + if (auto attr = dyn_cast(ofr)) + return std::to_string(cast(attr).getInt()); + return getName(cast(ofr)); + }; + auto destType = op.getType(); + if (isMoveOnly && !isContextCopyableMoveOnlyType(elemType) && + !canMoveValueIntoConsumer(op.getSource())) { + return op.emitOpError( + "cannot duplicate move-only plaintext/constant tensor slice"); + } + os << (isMoveOnly ? "std::move(" : "std::copy(") << getName(op.getSource()) + << ".begin(), " << getName(op.getSource()) << ".end(), " << resultName + << ".begin() + "; + for (unsigned i = 0; i < offsets.size(); ++i) { + if (i > 0) os << " + "; + int64_t stride = 1; + for (unsigned j = i + 1; j < destType.getRank(); ++j) + stride *= destType.getDimSize(j); + os << getOffsetStr(offsets[i]); + if (stride != 1) os << " * " << stride; + } + os << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Additional arith ops +//===----------------------------------------------------------------------===// + +LogicalResult CheddarEmitter::printOperation(arith::AddIOp op) { + auto typeStr = convertType(op.getResult().getType()); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(op.getResult()) << " = " + << getName(op.getLhs()) << " + " << getName(op.getRhs()) << ";\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(arith::MulIOp op) { + auto typeStr = convertType(op.getResult().getType()); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(op.getResult()) << " = " + << getName(op.getLhs()) << " * " << getName(op.getRhs()) << ";\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(arith::SubIOp op) { + auto typeStr = convertType(op.getResult().getType()); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(op.getResult()) << " = " + << getName(op.getLhs()) << " - " << getName(op.getRhs()) << ";\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(arith::FloorDivSIOp op) { + auto typeStr = convertType(op.getResult().getType()); + if (failed(typeStr)) return failure(); + auto lhs = getName(op.getLhs()); + auto rhs = getName(op.getRhs()); + os << *typeStr << " " << getName(op.getResult()) << " = (" << lhs << " / " + << rhs << ") - ((" << lhs << " % " << rhs << " != 0) && ((" << lhs + << " < 0) != (" << rhs << " < 0)));\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(arith::RemSIOp op) { + auto typeStr = convertType(op.getResult().getType()); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(op.getResult()) << " = " + << getName(op.getLhs()) << " % " << getName(op.getRhs()) << ";\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(arith::CmpIOp op) { + std::string cmpOp; + switch (op.getPredicate()) { + case arith::CmpIPredicate::eq: + cmpOp = "=="; + break; + case arith::CmpIPredicate::ne: + cmpOp = "!="; + break; + case arith::CmpIPredicate::slt: + case arith::CmpIPredicate::ult: + cmpOp = "<"; + break; + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::ule: + cmpOp = "<="; + break; + case arith::CmpIPredicate::sgt: + case arith::CmpIPredicate::ugt: + cmpOp = ">"; + break; + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::uge: + cmpOp = ">="; + break; + } + os << "bool " << getName(op.getResult()) << " = " << getName(op.getLhs()) + << " " << cmpOp << " " << getName(op.getRhs()) << ";\n"; + return success(); +} + +LogicalResult CheddarEmitter::printOperation(arith::IndexCastOp op) { + auto typeStr = convertType(op.getOut().getType()); + if (failed(typeStr)) return failure(); + os << *typeStr << " " << getName(op.getOut()) << " = static_cast<" << *typeStr + << ">(" << getName(op.getIn()) << ");\n"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Translation registration +//===----------------------------------------------------------------------===// + +LogicalResult translateToCheddar(Operation *op, llvm::raw_ostream &os, + bool use64Bit, const std::string ¶msJson) { + SelectVariableNames variableNames(op); + // Two-pass: buffer body, then emit prelude + body + std::string bufferedStr; + llvm::raw_string_ostream strOs(bufferedStr); + CheddarEmitter emitter(strOs, &variableNames, use64Bit, paramsJson); + LogicalResult result = emitter.translate(*op); + if (failed(result)) return result; + + emitter.emitPrelude(os); + os << strOs.str(); + return success(); +} + +struct CheddarTranslateOptions { + llvm::cl::opt use64Bit{"cheddar-use-64bit", + llvm::cl::desc("Use uint64_t word type"), + llvm::cl::init(true)}; + llvm::cl::opt paramsJson{ + "cheddar-params-json", + llvm::cl::desc("Path to CHEDDAR parameter JSON file"), + llvm::cl::init("")}; +}; +static llvm::ManagedStatic cheddarTranslateOptions; + +void registerCheddarTranslateOptions() { *cheddarTranslateOptions; } + +static void registerRelevantDialects(DialectRegistry ®istry) { + registry.insert(); +} + +void registerToCheddarTranslation() { + TranslateFromMLIRRegistration reg( + "emit-cheddar", + "translate the cheddar dialect to C++ code against the CHEDDAR API", + [](Operation *op, llvm::raw_ostream &output) { + return translateToCheddar(op, output, cheddarTranslateOptions->use64Bit, + cheddarTranslateOptions->paramsJson); + }, + registerRelevantDialects); +} + +LogicalResult translateToCheddarHeader(Operation *op, llvm::raw_ostream &os, + bool use64Bit) { + // Emit a proper C++ header: includes, type aliases, function declarations. + auto moduleOp = dyn_cast(op); + if (!moduleOp) return failure(); + + SelectVariableNames variableNames(op); + + os << "#pragma once\n"; + os << kStdIncludes; + os << kCheddarInclude; + + // Check if extensions are needed + bool needsExtension = false; + moduleOp->walk([&](Operation *innerOp) { + if (isa(innerOp)) + needsExtension = true; + }); + if (needsExtension) os << kCheddarExtensionInclude; + + os << "\n"; + if (use64Bit) { + os << kTypeAliasPrelude64; + } else { + os << kTypeAliasPrelude32; + } + os << "\n"; + + // Emit function declarations + for (auto funcOp : moduleOp.getOps()) { + auto funcType = funcOp.getFunctionType(); + auto resultTypes = funcType.getResults(); + auto argTypes = funcType.getInputs(); + + // Two-pass: buffer to determine return type string + CheddarEmitter tempEmitter(llvm::nulls(), &variableNames, use64Bit); + + // Return type + if (resultTypes.empty()) { + os << "void "; + } else if (resultTypes.size() == 1) { + auto typeStr = tempEmitter.convertType(resultTypes[0]); + if (failed(typeStr)) return failure(); + os << *typeStr << " "; + } else { + os << "std::tuple<"; + for (unsigned i = 0; i < resultTypes.size(); ++i) { + if (i > 0) os << ", "; + auto typeStr = tempEmitter.convertType(resultTypes[i]); + if (failed(typeStr)) return failure(); + os << *typeStr; + } + os << "> "; + } + + // Function name and arg types + os << funcOp.getName() << "("; + for (unsigned i = 0; i < argTypes.size(); ++i) { + if (i > 0) os << ", "; + auto typeStr = tempEmitter.convertType(argTypes[i], /*asArg=*/true); + if (failed(typeStr)) return failure(); + os << *typeStr; + } + os << ");\n"; + } + + // Emit __configure declaration if module has the required params + bool hasLogN = + moduleOp->getAttrOfType("cheddar.logN") != nullptr; + bool hasQ = + moduleOp->getAttrOfType("cheddar.Q") != nullptr; + if (hasLogN && hasQ) { + os << "\nstd::tuple __configure();\n"; + } + + return success(); +} + +void registerToCheddarHeaderTranslation() { + TranslateFromMLIRRegistration reg( + "emit-cheddar-header", + "translate the cheddar dialect to a C++ header file", + [](Operation *op, llvm::raw_ostream &output) { + return translateToCheddarHeader(op, output, + cheddarTranslateOptions->use64Bit); + }, + registerRelevantDialects); +} + +} // namespace cheddar +} // namespace heir +} // namespace mlir diff --git a/lib/Target/Cheddar/CheddarEmitter.h b/lib/Target/Cheddar/CheddarEmitter.h new file mode 100644 index 0000000000..699623aa52 --- /dev/null +++ b/lib/Target/Cheddar/CheddarEmitter.h @@ -0,0 +1,162 @@ +#ifndef LIB_TARGET_CHEDDAR_CHEDDAREMITTER_H_ +#define LIB_TARGET_CHEDDAR_CHEDDAREMITTER_H_ + +#include +#include +#include +#include + +#include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" +#include "lib/Dialect/Cheddar/IR/CheddarOps.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace cheddar { + +/// Emits C++ code targeting the CHEDDAR GPU FHE library API. +class CheddarEmitter { + public: + CheddarEmitter(raw_ostream &os, SelectVariableNames *variableNames, + bool use64Bit = true, const std::string ¶msJsonPath = ""); + + LogicalResult translate(Operation &operation); + + void emitPrelude(raw_ostream &os) const; + + private: + raw_indented_ostream os; + SelectVariableNames *variableNames; + bool use64Bit; + std::string paramsJsonPath; + + // Track which include groups are needed + bool needsExtensionIncludes = false; + bool needsJsonIncludes = false; + + // Counter for generating unique temporary variable names (e.g., for copies + // of multi-use move-only values). + int tempVarCounter = 0; + + // Module-level scheme params (read from cheddar.* attrs) + int64_t logDefaultScale = -1; + int64_t logN = 0; + std::vector qPrimes; + std::vector pPrimes; + std::set rotationDistances; + + // Per-op printers + LogicalResult printOperation(ModuleOp op); + LogicalResult printOperation(func::FuncOp op); + LogicalResult printOperation(func::ReturnOp op); + LogicalResult printOperation(func::CallOp op); + + // Cheddar dialect ops + LogicalResult printOperation(CreateContextOp op); + LogicalResult printOperation(CreateUserInterfaceOp op); + LogicalResult printOperation(GetEncoderOp op); + LogicalResult printOperation(GetEvkMapOp op); + LogicalResult printOperation(GetMultKeyOp op); + LogicalResult printOperation(GetRotKeyOp op); + LogicalResult printOperation(GetConjKeyOp op); + LogicalResult printOperation(PrepareRotKeyOp op); + + LogicalResult printOperation(EncodeOp op); + LogicalResult printOperation(EncodeConstantOp op); + LogicalResult printOperation(DecodeOp op); + LogicalResult printOperation(EncryptOp op); + LogicalResult printOperation(DecryptOp op); + + LogicalResult printOperation(AddOp op); + LogicalResult printOperation(SubOp op); + LogicalResult printOperation(MultOp op); + LogicalResult printOperation(NegOp op); + + LogicalResult printOperation(AddPlainOp op); + LogicalResult printOperation(SubPlainOp op); + LogicalResult printOperation(MultPlainOp op); + LogicalResult printOperation(AddConstOp op); + LogicalResult printOperation(MultConstOp op); + + LogicalResult printOperation(RescaleOp op); + LogicalResult printOperation(LevelDownOp op); + LogicalResult printOperation(RelinearizeOp op); + LogicalResult printOperation(RelinearizeRescaleOp op); + + LogicalResult printOperation(HMultOp op); + LogicalResult printOperation(HRotOp op); + LogicalResult printOperation(HRotAddOp op); + LogicalResult printOperation(HConjOp op); + LogicalResult printOperation(HConjAddOp op); + LogicalResult printOperation(MadUnsafeOp op); + + LogicalResult printOperation(BootOp op); + LogicalResult printOperation(LinearTransformOp op); + LogicalResult printOperation(EvalPolyOp op); + + // Arith dialect ops + LogicalResult printOperation(arith::ConstantOp op); + LogicalResult printOperation(arith::AddIOp op); + LogicalResult printOperation(arith::MulIOp op); + LogicalResult printOperation(arith::SubIOp op); + LogicalResult printOperation(arith::FloorDivSIOp op); + LogicalResult printOperation(arith::RemSIOp op); + LogicalResult printOperation(arith::CmpIOp op); + LogicalResult printOperation(arith::IndexCastOp op); + + // SCF ops + LogicalResult printOperation(scf::ForOp op); + LogicalResult printOperation(scf::IfOp op); + LogicalResult printOperation(scf::YieldOp op); + + // Tensor ops + LogicalResult printOperation(tensor::EmptyOp op); + LogicalResult printOperation(tensor::ExtractOp op); + LogicalResult printOperation(tensor::InsertOp op); + LogicalResult printOperation(tensor::SplatOp op); + LogicalResult printOperation(tensor::FromElementsOp op); + LogicalResult printOperation(tensor::ExpandShapeOp op); + LogicalResult printOperation(tensor::ExtractSliceOp op); + LogicalResult printOperation(tensor::InsertSliceOp op); + + public: + // Type conversion (public for header emission) + // asArg=true emits const-ref for move-only CHEDDAR types (function params) + // asArg=false emits by-value (return types, local variables) + FailureOr convertType(Type type, bool asArg = false); + + private: + std::string getName(Value value); + std::string getContextName(Operation *op); + void emitScaleMismatchDebugCheck(StringRef opKind, StringRef resultName, + Value lhs, Value rhs); + LogicalResult emitVectorDeepCopy(Operation *op, StringRef destName, + StringRef srcName, Type elemType, + Type tensorType); +}; + +/// Free functions for translation registration + +LogicalResult translateToCheddar(Operation *op, llvm::raw_ostream &os, + bool use64Bit, const std::string ¶msJson); + +void registerToCheddarTranslation(); +void registerToCheddarHeaderTranslation(); +void registerCheddarTranslateOptions(); + +} // namespace cheddar +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_CHEDDAR_CHEDDAREMITTER_H_ diff --git a/lib/Target/Cheddar/CheddarTemplates.h b/lib/Target/Cheddar/CheddarTemplates.h new file mode 100644 index 0000000000..690716abd8 --- /dev/null +++ b/lib/Target/Cheddar/CheddarTemplates.h @@ -0,0 +1,81 @@ +#ifndef LIB_TARGET_CHEDDAR_CHEDDARTEMPLATES_H_ +#define LIB_TARGET_CHEDDAR_CHEDDARTEMPLATES_H_ + +#include + +namespace mlir { +namespace heir { +namespace cheddar { + +// Includes emitted at the top of generated files +// CHEDDAR headers use unnamespaced paths (core/, extension/). +constexpr std::string_view kCheddarInclude = R"cpp( +#include "core/Context.h" +#include "core/Container.h" +#include "core/Parameter.h" +#include "core/Encode.h" +#include "core/EvkMap.h" +#include "core/EvkRequest.h" +#include "UserInterface.h" +)cpp"; + +constexpr std::string_view kCheddarExtensionInclude = R"cpp( +#include "extension/BootContext.h" +#include "extension/LinearTransform.h" +#include "extension/EvalPoly.h" +#include "extension/Hoist.h" +)cpp"; + +constexpr std::string_view kStdIncludes = R"cpp( +#include +#include +#include +#include +#include +#include +#include +#include +#include +)cpp"; + +constexpr std::string_view kJsonInclude = R"cpp( +#include +#include +)cpp"; + +// The type alias prelude, parametric on word type +constexpr std::string_view kTypeAliasPrelude64 = R"cpp( + using namespace cheddar; + using word = uint64_t; + using Ct = Ciphertext; + using Pt = Plaintext; + using Const = Constant; + using Evk = EvaluationKey; + using EvkMapT = EvkMap; + using CtxPtr = std::shared_ptr>; + using Param = Parameter; + using UI = UserInterface; + using Enc = Encoder; + using Complex = std::complex; +)cpp"; + +constexpr std::string_view kTypeAliasPrelude32 = R"cpp( + using namespace cheddar; + using word = uint32_t; + using Ct = Ciphertext; + using Pt = Plaintext; + using Const = Constant; + using Evk = EvaluationKey; + using EvkMapT = EvkMap; + using CtxPtr = std::shared_ptr>; + using Param = Parameter; + using UI = UserInterface; + using Enc = Encoder; + using Complex = std::complex; +)cpp"; + +} // namespace cheddar +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_CHEDDAR_CHEDDARTEMPLATES_H_ diff --git a/lib/Transforms/AnnotateModule/AnnotateModule.cpp b/lib/Transforms/AnnotateModule/AnnotateModule.cpp index c86dbc2983..a425c8da20 100644 --- a/lib/Transforms/AnnotateModule/AnnotateModule.cpp +++ b/lib/Transforms/AnnotateModule/AnnotateModule.cpp @@ -30,6 +30,8 @@ struct AnnotateModule : impl::AnnotateModuleBase { moduleSetOpenfhe(module); } else if (backend == "lattigo") { moduleSetLattigo(module); + } else if (backend == "cheddar") { + moduleSetCheddar(module); } } }; diff --git a/patches/cheddar.patch b/patches/cheddar.patch new file mode 100644 index 0000000000..699d9080d8 --- /dev/null +++ b/patches/cheddar.patch @@ -0,0 +1,70 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 323e218..cb095e9 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -6,7 +6,10 @@ set (CMAKE_CXX_STANDARD 17) + set (CMAKE_CUDA_STANDARD 17) + set (CMAKE_EXPORT_COMPILE_COMMANDS ON) + +-set (CMAKE_CUDA_ARCHITECTURES 60 61 70 75 80 86 89 90) ++# Respect caller-provided CMAKE_CUDA_ARCHITECTURES (e.g. `native` from bazel). ++if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) ++ set(CMAKE_CUDA_ARCHITECTURES 60 61 70 75 80 86 89 90) ++endif() + + # Automatically detect the CUDA architecture + #if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) +@@ -24,6 +27,7 @@ FetchContent_Declare( + GIT_REPOSITORY https://github.com/rapidsai/rmm + GIT_TAG branch-22.12 + GIT_SHALLOW ++ EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(rmm) + message(STATUS "RMM source dir: ${rmm_SOURCE_DIR}") +@@ -93,7 +97,7 @@ target_link_libraries(cheddar + ) + + target_include_directories(cheddar +- PUBLIC include ++ PUBLIC include ${CUDAToolkit_INCLUDE_DIRS} ${rmm_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/compat + ) + + target_compile_options(cheddar PRIVATE +@@ -112,3 +116,16 @@ if (BUILD_UNITTEST) + message(STATUS "Building unit tests") + add_subdirectory(unittest) + endif() ++ ++# Install targets for bazel compatibility. ++# CHEDDAR's public headers leak transitive deps (RMM, spdlog), so we ++# install all FetchContent-provided headers alongside CHEDDAR's own. ++install(TARGETS cheddar LIBRARY DESTINATION lib) ++install(DIRECTORY include/ DESTINATION include) ++install(DIRECTORY compat/ DESTINATION include) ++file(GLOB _fetchcontent_dirs "${CMAKE_BINARY_DIR}/_deps/*-src/include") ++foreach(_dir ${_fetchcontent_dirs}) ++ if(IS_DIRECTORY ${_dir}) ++ install(DIRECTORY ${_dir}/ DESTINATION include) ++ endif() ++endforeach() +diff --git a/compat/thrust/optional.h b/compat/thrust/optional.h +new file mode 100644 +index 0000000..1111111 +--- /dev/null ++++ b/compat/thrust/optional.h +@@ -0,0 +1,14 @@ ++// Compatibility shim: thrust::optional was removed in CCCL 3.x (CUDA 13). ++// RMM branch-22.12 (pinned by cheddar's CMakeLists) still references ++// . This shim redirects to cuda::std::optional. ++// ++// Delete this file once RMM is bumped to a version that no longer uses ++// thrust::optional (released alongside CCCL 3.x). ++#pragma once ++ ++#include ++ ++namespace thrust { ++using ::cuda::std::nullopt; ++using ::cuda::std::optional; ++} // namespace thrust diff --git a/tests/Dialect/Cheddar/IR/BUILD b/tests/Dialect/Cheddar/IR/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/Cheddar/IR/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/Cheddar/IR/roundtrip.mlir b/tests/Dialect/Cheddar/IR/roundtrip.mlir new file mode 100644 index 0000000000..0b4485573b --- /dev/null +++ b/tests/Dialect/Cheddar/IR/roundtrip.mlir @@ -0,0 +1,368 @@ +// RUN: heir-opt %s | FileCheck %s + +// Test that the cheddar dialect can be parsed and printed. + +// --- Setup operations --- + +// CHECK: @test_create_context +func.func @test_create_context(%params: !cheddar.parameter) -> !cheddar.context { + // CHECK: cheddar.create_context + %ctx = cheddar.create_context %params : (!cheddar.parameter) -> !cheddar.context + return %ctx : !cheddar.context +} + +// CHECK: @test_create_user_interface +func.func @test_create_user_interface(%ctx: !cheddar.context) -> !cheddar.user_interface { + // CHECK: cheddar.create_user_interface + %ui = cheddar.create_user_interface %ctx : (!cheddar.context) -> !cheddar.user_interface + return %ui : !cheddar.user_interface +} + +// CHECK: @test_get_encoder +func.func @test_get_encoder(%ctx: !cheddar.context) -> !cheddar.encoder { + // CHECK: cheddar.get_encoder + %enc = cheddar.get_encoder %ctx : (!cheddar.context) -> !cheddar.encoder + return %enc : !cheddar.encoder +} + +// CHECK: @test_get_evk_map +func.func @test_get_evk_map(%ui: !cheddar.user_interface) -> !cheddar.evk_map { + // CHECK: cheddar.get_evk_map + %evk = cheddar.get_evk_map %ui : (!cheddar.user_interface) -> !cheddar.evk_map + return %evk : !cheddar.evk_map +} + +// CHECK: @test_get_mult_key +func.func @test_get_mult_key(%ui: !cheddar.user_interface) -> !cheddar.eval_key { + // CHECK: cheddar.get_mult_key + %key = cheddar.get_mult_key %ui : (!cheddar.user_interface) -> !cheddar.eval_key + return %key : !cheddar.eval_key +} + +// CHECK: @test_get_rot_key +func.func @test_get_rot_key(%ui: !cheddar.user_interface) -> !cheddar.eval_key { + // CHECK: cheddar.get_rot_key + // CHECK-SAME: distance = 5 + %key = cheddar.get_rot_key %ui {distance = 5 : i64} : (!cheddar.user_interface) -> !cheddar.eval_key + return %key : !cheddar.eval_key +} + +// CHECK: @test_get_conj_key +func.func @test_get_conj_key(%ui: !cheddar.user_interface) -> !cheddar.eval_key { + // CHECK: cheddar.get_conj_key + %key = cheddar.get_conj_key %ui : (!cheddar.user_interface) -> !cheddar.eval_key + return %key : !cheddar.eval_key +} + +// CHECK: @test_prepare_rot_key +func.func @test_prepare_rot_key(%ui: !cheddar.user_interface) { + // CHECK: cheddar.prepare_rot_key + // CHECK-SAME: distance = 3 + // CHECK-SAME: maxLevel = 10 + cheddar.prepare_rot_key %ui {distance = 3 : i64, maxLevel = 10 : i64} : (!cheddar.user_interface) -> () + return +} + +// --- Encode / Encrypt / Decrypt --- + +// CHECK: @test_encode +func.func @test_encode( + %enc: !cheddar.encoder, + %msg: tensor<4xf64>) -> !cheddar.plaintext { + // CHECK: cheddar.encode + // CHECK-SAME: level = 5 + // CHECK-SAME: scale = 45 + %pt = cheddar.encode %enc, %msg {level = 5 : i64, scale = 45 : i64} : (!cheddar.encoder, tensor<4xf64>) -> !cheddar.plaintext + return %pt : !cheddar.plaintext +} + +// CHECK: @test_encode_constant +func.func @test_encode_constant( + %enc: !cheddar.encoder, + %val: f64) -> !cheddar.constant { + // CHECK: cheddar.encode_constant + // CHECK-SAME: level = 3 + // CHECK-SAME: scale = 45 + %c = cheddar.encode_constant %enc, %val {level = 3 : i64, scale = 45 : i64} : (!cheddar.encoder, f64) -> !cheddar.constant + return %c : !cheddar.constant +} + +// CHECK: @test_decode +func.func @test_decode( + %enc: !cheddar.encoder, + %pt: !cheddar.plaintext) -> tensor<4xf64> { + // CHECK: cheddar.decode + %msg = cheddar.decode %enc, %pt : (!cheddar.encoder, !cheddar.plaintext) -> tensor<4xf64> + return %msg : tensor<4xf64> +} + +// CHECK: @test_encrypt +func.func @test_encrypt( + %ui: !cheddar.user_interface, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.encrypt + %ct = cheddar.encrypt %ui, %pt : (!cheddar.user_interface, !cheddar.plaintext) -> !cheddar.ciphertext + return %ct : !cheddar.ciphertext +} + +// CHECK: @test_decrypt +func.func @test_decrypt( + %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext) -> !cheddar.plaintext { + // CHECK: cheddar.decrypt + %pt = cheddar.decrypt %ui, %ct : (!cheddar.user_interface, !cheddar.ciphertext) -> !cheddar.plaintext + return %pt : !cheddar.plaintext +} + +// --- Binary ct-ct operations --- + +// CHECK: @test_add +func.func @test_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.add + %result = cheddar.add %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_sub +func.func @test_sub( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.sub + %result = cheddar.sub %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mult +func.func @test_mult( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.mult + %result = cheddar.mult %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Ct-pt / ct-const operations --- + +// CHECK: @test_add_plain +func.func @test_add_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.add_plain + %result = cheddar.add_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_sub_plain +func.func @test_sub_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.sub_plain + %result = cheddar.sub_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mult_plain +func.func @test_mult_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: cheddar.mult_plain + %result = cheddar.mult_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_add_const +func.func @test_add_const( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: cheddar.add_const + %result = cheddar.add_const %ctx, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mult_const +func.func @test_mult_const( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: cheddar.mult_const + %result = cheddar.mult_const %ctx, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Unary operations --- + +// CHECK: @test_neg +func.func @test_neg( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.neg + %result = cheddar.neg %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_rescale +func.func @test_rescale( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.rescale + %result = cheddar.rescale %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_level_down +func.func @test_level_down( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: cheddar.level_down + // CHECK-SAME: targetLevel = 3 + %result = cheddar.level_down %ctx, %ct {targetLevel = 3 : i64} : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Key-switching operations --- + +// CHECK: @test_relinearize +func.func @test_relinearize( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.relinearize + %result = cheddar.relinearize %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_relinearize_rescale +func.func @test_relinearize_rescale( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.relinearize_rescale + %result = cheddar.relinearize_rescale %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Fused compound operations --- + +// CHECK: @test_hmult +func.func @test_hmult( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hmult + %result = cheddar.hmult %ctx, %ct0, %ct1, %key {rescale = true} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hmult_no_rescale +func.func @test_hmult_no_rescale( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hmult + // CHECK-SAME: rescale = false + %result = cheddar.hmult %ctx, %ct0, %ct1, %key {rescale = false} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hrot +func.func @test_hrot( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hrot + // CHECK-SAME: static_shift = 5 + %result = cheddar.hrot %ctx, %ct, %key {static_shift = 5 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hrot_add +func.func @test_hrot_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hrot_add + // CHECK-SAME: distance = 3 + %result = cheddar.hrot_add %ctx, %ct0, %ct1, %key {distance = 3 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hconj +func.func @test_hconj( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hconj + %result = cheddar.hconj %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_hconj_add +func.func @test_hconj_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: cheddar.hconj_add + %result = cheddar.hconj_add %ctx, %ct0, %ct1, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_mad_unsafe +func.func @test_mad_unsafe( + %ctx: !cheddar.context, + %acc: !cheddar.ciphertext, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: cheddar.mad_unsafe + %result = cheddar.mad_unsafe %ctx, %acc, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Extension operations --- + +// CHECK: @test_boot +func.func @test_boot( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + // CHECK: cheddar.boot + %result = cheddar.boot %ctx, %ct, %evk : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_linear_transform +func.func @test_linear_transform( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map, + %diags: tensor<2x4xf64>) -> !cheddar.ciphertext { + // CHECK: cheddar.linear_transform + // CHECK-SAME: diagonal_indices = array + // CHECK-SAME: level = 5 + // CHECK-SAME: logBabyStepGiantStepRatio = 0 + %result = cheddar.linear_transform %ctx, %ct, %evk, %diags {diagonal_indices = array, level = 5 : i64, logBabyStepGiantStepRatio = 0 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map, tensor<2x4xf64>) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: @test_eval_poly +func.func @test_eval_poly( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %evk: !cheddar.evk_map) -> !cheddar.ciphertext { + // CHECK: cheddar.eval_poly + // CHECK-SAME: coefficients = [1.000000e+00, 2.000000e+00, 3.000000e+00] + %result = cheddar.eval_poly %ctx, %ct, %evk {coefficients = [1.0 : f64, 2.0 : f64, 3.0 : f64], level = 5 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.evk_map) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} diff --git a/tests/Dialect/Cheddar/Transforms/BUILD b/tests/Dialect/Cheddar/Transforms/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/Cheddar/Transforms/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/Cheddar/Transforms/fuse_ops.mlir b/tests/Dialect/Cheddar/Transforms/fuse_ops.mlir new file mode 100644 index 0000000000..0b7f518bd3 --- /dev/null +++ b/tests/Dialect/Cheddar/Transforms/fuse_ops.mlir @@ -0,0 +1,82 @@ +// RUN: heir-opt --cheddar-fuse-ops %s | FileCheck %s + +// Test: mult + relinearize + rescale -> hmult(rescale=true) +// CHECK: @test_fuse_hmult_with_rescale +func.func @test_fuse_hmult_with_rescale( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK-NOT: cheddar.mult + // CHECK-NOT: cheddar.relinearize + // CHECK-NOT: cheddar.rescale + // CHECK: cheddar.hmult + // rescale=true is the default and gets elided from the attr-dict + // CHECK-NOT: rescale = false + %mult = cheddar.mult %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %relin = cheddar.relinearize %ctx, %mult, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + %rescaled = cheddar.rescale %ctx, %relin : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %rescaled : !cheddar.ciphertext +} + +// Test: mult + relinearize -> hmult(rescale=false) +// CHECK: @test_fuse_hmult_no_rescale +func.func @test_fuse_hmult_no_rescale( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK-NOT: cheddar.mult + // CHECK-NOT: cheddar.relinearize + // CHECK: cheddar.hmult + // CHECK-SAME: rescale = false + %mult = cheddar.mult %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %relin = cheddar.relinearize %ctx, %mult, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %relin : !cheddar.ciphertext +} + +// Test: hrot + add -> hrot_add +// CHECK: @test_fuse_hrot_add +func.func @test_fuse_hrot_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK-NOT: cheddar.hrot + // CHECK-NOT: cheddar.add + // CHECK: cheddar.hrot_add + %rotated = cheddar.hrot %ctx, %ct0, %key {static_shift = 3 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + %sum = cheddar.add %ctx, %rotated, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %sum : !cheddar.ciphertext +} + +// Test: hconj + add -> hconj_add +// CHECK: @test_fuse_hconj_add +func.func @test_fuse_hconj_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK-NOT: cheddar.hconj + // CHECK-NOT: cheddar.add + // CHECK: cheddar.hconj_add + %conjugated = cheddar.hconj %ctx, %ct0, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + %sum = cheddar.add %ctx, %conjugated, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %sum : !cheddar.ciphertext +} + +// Test: mult + relinearize_rescale -> hmult(rescale=true) +// CHECK: @test_fuse_hmult_with_relin_rescale +func.func @test_fuse_hmult_with_relin_rescale( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK-NOT: cheddar.mult + // CHECK-NOT: cheddar.relinearize_rescale + // CHECK: cheddar.hmult + // CHECK-NOT: rescale = false + %mult = cheddar.mult %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + %relin_rescaled = cheddar.relinearize_rescale %ctx, %mult, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %relin_rescaled : !cheddar.ciphertext +} diff --git a/tests/Dialect/LWE/Conversions/lwe_to_cheddar/BUILD b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/LWE/Conversions/lwe_to_cheddar/func_call.mlir b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/func_call.mlir new file mode 100644 index 0000000000..a7e1a88da3 --- /dev/null +++ b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/func_call.mlir @@ -0,0 +1,47 @@ +// RUN: heir-opt --lwe-to-cheddar %s | FileCheck %s + +!Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> +!Z36028797018652673_i64 = !mod_arith.int<36028797018652673 : i64> +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +#layout = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : i0 = 0 and ct = 0 and (-i1 + slot) mod 512 = 0 and 0 <= i1 <= 511 and 0 <= slot <= 1023 }"> +#layout1 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : (i0 - i1 + ct) mod 512 = 0 and (-i1 + ct + slot) mod 1024 = 0 and 0 <= i0 <= 511 and 0 <= i1 <= 783 and 0 <= ct <= 511 and 0 <= slot <= 1023 }"> +#layout2 = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : i0 = 0 and ct = 0 and (-i1 + slot) mod 1024 = 0 and 0 <= i1 <= 783 and 0 <= slot <= 1023 }"> +#modulus_chain_L1_C1 = #lwe.modulus_chain, current = 1> +#ring_f64_1_x1024 = #polynomial.ring> +!rns_L1 = !rns.rns +#original_type = #tensor_ext.original_type, layout = #layout> +!pt = !lwe.lwe_plaintext> +#ring_rns_L1_1_x1024 = #polynomial.ring> +#ciphertext_space_L1 = #lwe.ciphertext_space +!ct_L1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> + +// CHECK-DAG: !ctx = !cheddar.context +// CHECK-DAG: !encoder = !cheddar.encoder +// CHECK-DAG: !ui = !cheddar.user_interface +// CHECK-DAG: !ct = !cheddar.ciphertext +module @jit_func attributes {backend.cheddar, ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { +// CHECK: func.func @mnist__preprocessed(%ctx: !ctx, %encoder: !encoder, %ui: !ui +// CHECK: %arg1: tensor<1x!ct> +// CHECK: func.func public @mnist(%ctx: !ctx, %encoder: !encoder, %ui: !ui +// CHECK: %arg1: tensor<1x!ct> +// CHECK: %0 = call @_assign_layout_11979326689855340354(%arg0) +// CHECK: %1 = call @mnist__preprocessed(%ctx, %encoder, %ui, %0, %arg1) : (!ctx, !encoder, !ui, tensor<512x1024xf32>, tensor<1x!ct>) -> tensor<1x!ct> + func.func private @_assign_layout_11979326689855340354(tensor<512x784xf32>) -> tensor<512x1024xf32> attributes {client.pack_func = {func_name = "mnist"}} + + func.func @mnist__preprocessed(%arg0: tensor<512x1024xf32> {tensor_ext.original_type = #tensor_ext.original_type, layout = #layout1>}, %arg1: tensor<1x!ct_L1> {tensor_ext.original_type = #tensor_ext.original_type, layout = #layout2>}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) attributes {client.preprocessed_func = {func_name = "mnist"}} { + %c0 = arith.constant 0 : index + %extracted_slice = tensor.extract_slice %arg0[0, 0] [1, 1024] [1, 1] : tensor<512x1024xf32> to tensor<1024xf32> + %pt = lwe.rlwe_encode %extracted_slice {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %extracted = tensor.extract %arg1[%c0] : tensor<1x!ct_L1> + %ct = lwe.rmul_plain %extracted, %pt : (!ct_L1, !pt) -> !ct_L1 + %from_elements = tensor.from_elements %ct : tensor<1x!ct_L1> + return %from_elements : tensor<1x!ct_L1> + } + + func.func public @mnist(%arg0: tensor<512x784xf32>, %arg1: tensor<1x!ct_L1> {tensor_ext.original_type = #tensor_ext.original_type, layout = #layout2>}) -> (tensor<1x!ct_L1> {jax.result_info = "result[0]", tensor_ext.original_type = #original_type}) { + %0 = call @_assign_layout_11979326689855340354(%arg0) : (tensor<512x784xf32>) -> tensor<512x1024xf32> + %1 = call @mnist__preprocessed(%0, %arg1) {arg_attrs = [{mhlo.sharding = "{replicated}", tensor_ext.layout = #layout1}, {tensor_ext.layout = #layout2}]} : (tensor<512x1024xf32>, tensor<1x!ct_L1>) -> tensor<1x!ct_L1> + return %1 : tensor<1x!ct_L1> + } +} diff --git a/tests/Dialect/LWE/Conversions/lwe_to_cheddar/lwe_to_cheddar.mlir b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/lwe_to_cheddar.mlir new file mode 100644 index 0000000000..df1aa05fc0 --- /dev/null +++ b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/lwe_to_cheddar.mlir @@ -0,0 +1,131 @@ +// RUN: heir-opt --lwe-to-cheddar %s | FileCheck %s + +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +!Z36028797018652673_i64 = !mod_arith.int<36028797018652673 : i64> +!Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> +!rns_L1 = !rns.rns +!rns_L0 = !rns.rns +#ring_f64_1_x1024 = #polynomial.ring> +#ring_rns_L1_1_x1024 = #polynomial.ring> +#ring_rns_L0_1_x1024 = #polynomial.ring> +#ciphertext_space_L1 = #lwe.ciphertext_space +#ciphertext_space_L1_D3 = #lwe.ciphertext_space +#ciphertext_space_L0 = #lwe.ciphertext_space +#modulus_chain_L1_C1 = #lwe.modulus_chain, current = 1> +#modulus_chain_L1_C0 = #lwe.modulus_chain, current = 0> +#plaintext_space = #lwe.plaintext_space +!pt = !lwe.lwe_plaintext +!ct_L1 = !lwe.lwe_ciphertext +!ct_L1_D3 = !lwe.lwe_ciphertext +!ct_L0 = !lwe.lwe_ciphertext + +#inverse_canonical_encoding_lo = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding_hi = #lwe.inverse_canonical_encoding +#plaintext_space_lo = #lwe.plaintext_space +#plaintext_space_hi = #lwe.plaintext_space +!ct_L1_exact = !lwe.lwe_ciphertext +!ct_L1_exact_1 = !lwe.lwe_ciphertext +!ct_L0_exact_1 = !lwe.lwe_ciphertext + +// Verify the pass threads cheddar context, encoder, and UI args. +// CHECK-DAG: ![[CTX_T:.*]] = !cheddar.context +// CHECK-DAG: ![[ENC_T:.*]] = !cheddar.encoder +// CHECK-DAG: ![[UI_T:.*]] = !cheddar.user_interface +// CHECK-DAG: ![[CT_T:.*]] = !cheddar.ciphertext +// CHECK: func.func @test_add(%[[CTX:.*]]: ![[CTX_T]], %[[ENC:.*]]: ![[ENC_T]], %[[UI:.*]]: ![[UI_T]], %[[CT0:.*]]: ![[CT_T]], %[[CT1:.*]]: ![[CT_T]]) +module attributes {ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { + func.func @test_add(%ct0: !ct_L1, %ct1: !ct_L1) -> !ct_L1 { + // CHECK: cheddar.add %[[CTX]], %[[CT0]], %[[CT1]] + %result = lwe.radd %ct0, %ct1 : (!ct_L1, !ct_L1) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_sub + func.func @test_sub(%ct0: !ct_L1, %ct1: !ct_L1) -> !ct_L1 { + // CHECK: cheddar.sub + %result = lwe.rsub %ct0, %ct1 : (!ct_L1, !ct_L1) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_negate + func.func @test_negate(%ct: !ct_L1) -> !ct_L1 { + // CHECK: cheddar.neg + %result = ckks.negate %ct : !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_relin + func.func @test_relin(%ct: !ct_L1_D3) -> !ct_L1 { + // CHECK: cheddar.get_mult_key + // CHECK: cheddar.relinearize + %result = ckks.relinearize %ct {from_basis = array, to_basis = array} : (!ct_L1_D3) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_rotate + func.func @test_rotate(%ct: !ct_L1) -> !ct_L1 { + %c5 = arith.constant 5 : i32 + // CHECK: cheddar.get_rot_key + // CHECK: cheddar.hrot + %result = ckks.rotate %ct, %c5 : i32 : !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_add_plain + func.func @test_add_plain(%ct: !ct_L1, %pt: !pt) -> !ct_L1 { + // CHECK: cheddar.add_plain + %result = lwe.radd_plain %ct, %pt : (!ct_L1, !pt) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_sub_plain_ct_first + func.func @test_sub_plain_ct_first(%ct: !ct_L1, %pt: !pt) -> !ct_L1 { + // CHECK: cheddar.sub_plain + %result = lwe.rsub_plain %ct, %pt : (!ct_L1, !pt) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_sub_plain_pt_first + func.func @test_sub_plain_pt_first(%pt: !pt, %ct: !ct_L1) -> !ct_L1 { + // CHECK: cheddar.neg + // CHECK: cheddar.add_plain + %result = lwe.rsub_plain %pt, %ct : (!pt, !ct_L1) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_encode + func.func @test_encode(%ct: !ct_L1) -> !ct_L1 { + %cst = arith.constant dense<1.0> : tensor<1024xf64> + // CHECK: cheddar.encode + %pt = lwe.rlwe_encode %cst {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf64> -> !pt + // CHECK: cheddar.add_plain + %result = lwe.radd_plain %ct, %pt : (!ct_L1, !pt) -> !ct_L1 + return %result : !ct_L1 + } + + // CHECK: func.func @test_encrypt + func.func @test_encrypt(%pt: !pt, %pk: !lwe.lwe_public_key) -> !ct_L1 { + // CHECK: cheddar.encrypt + %ct = lwe.rlwe_encrypt %pt, %pk : (!pt, !lwe.lwe_public_key) -> !ct_L1 + return %ct : !ct_L1 + } + + // CHECK: func.func @test_decrypt + func.func @test_decrypt(%ct: !ct_L1, %sk: !lwe.lwe_secret_key) -> !pt { + // CHECK: cheddar.decrypt + %pt = lwe.rlwe_decrypt %ct, %sk : (!ct_L1, !lwe.lwe_secret_key) -> !pt + return %pt : !pt + } + + // CHECK: func.func @test_level_reduce_scaled + // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[RED:.*]] = cheddar.rescale %[[CTX]], %{{.*}} + // CHECK: %[[CONST:.*]] = cheddar.encode_constant %[[ENC]], %[[ONE]] {level = 0 : i64, scale = 45 : i64} + // CHECK: cheddar.mult_const %[[CTX]], %[[RED]], %[[CONST]] + func.func @test_level_reduce_scaled(%ct: !ct_L1_exact_1) -> !ct_L0_exact_1 { + %result = ckks.level_reduce %ct {levelToDrop = 1 : i64} : !ct_L1_exact_1 -> !ct_L0_exact_1 + return %result : !ct_L0_exact_1 + } + +} diff --git a/tests/Dialect/LWE/Conversions/lwe_to_cheddar/unsupported_encode_encoding.mlir b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/unsupported_encode_encoding.mlir new file mode 100644 index 0000000000..bd9d557c59 --- /dev/null +++ b/tests/Dialect/LWE/Conversions/lwe_to_cheddar/unsupported_encode_encoding.mlir @@ -0,0 +1,24 @@ +// RUN: not heir-opt --lwe-to-cheddar --verify-diagnostics %s + +// CHECK: error: 'lwe.rlwe_encode' op requires inverse-canonical CKKS plaintext encoding for CHEDDAR lowering + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> +!Z65537_i64 = !mod_arith.int<65537 : i64> +!rns_L0 = !rns.rns +#ring_Z65537_i64_1_x32 = #polynomial.ring> +#ring_rns_L0_1_x32 = #polynomial.ring> +#ciphertext_space_L0 = #lwe.ciphertext_space +#modulus_chain_L0_C0 = #lwe.modulus_chain, current = 0> +#plaintext_space = #lwe.plaintext_space +!pt = !lwe.lwe_plaintext +!ct_L0 = !lwe.lwe_ciphertext + +module attributes {scheme.ckks, ckks.schemeParam = #ckks.scheme_param} { + func.func @bad_encode(%ct: !ct_L0) -> !ct_L0 { + %cst = arith.constant dense<1> : tensor<32xi16> + %pt = lwe.rlwe_encode %cst {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32} : tensor<32xi16> -> !pt + %result = lwe.radd_plain %ct, %pt : (!ct_L0, !pt) -> !ct_L0 + return %result : !ct_L0 + } +} diff --git a/tests/Emitter/Cheddar/BUILD b/tests/Emitter/Cheddar/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Emitter/Cheddar/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Emitter/Cheddar/emit_cheddar.mlir b/tests/Emitter/Cheddar/emit_cheddar.mlir new file mode 100644 index 0000000000..923e3a6805 --- /dev/null +++ b/tests/Emitter/Cheddar/emit_cheddar.mlir @@ -0,0 +1,326 @@ +// RUN: heir-translate --emit-cheddar %s | FileCheck %s + +// CHECK: #include "core/Context.h" +// CHECK: using namespace cheddar; +// CHECK: using word = uint64_t; + +// --- Binary ct-ct operations --- + +// CHECK: Ct test_add( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]]) +func.func @test_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: if (std::getenv("HEIR_CHEDDAR_DEBUG_SCALES")) { + // CHECK: double lhs_scale = [[CT0]].GetScale(); + // CHECK-NEXT: double rhs_scale = [[CT1]].GetScale(); + // CHECK: [[CTX]]->Add([[RES]], [[CT0]], [[CT1]]); + %result = cheddar.add %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_sub( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]]) +func.func @test_sub( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: if (std::getenv("HEIR_CHEDDAR_DEBUG_SCALES")) { + // CHECK: double lhs_scale = [[CT0]].GetScale(); + // CHECK-NEXT: double rhs_scale = [[CT1]].GetScale(); + // CHECK: [[CTX]]->Sub([[RES]], [[CT0]], [[CT1]]); + %result = cheddar.sub %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_mult( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]]) +func.func @test_mult( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Mult([[RES]], [[CT0]], [[CT1]]); + %result = cheddar.mult %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: int32_t test_floor_div_mul( +// CHECK-SAME: int32_t [[A:.*]], int32_t [[B:.*]]) +func.func @test_floor_div_mul(%a: i32, %b: i32) -> i32 { + // CHECK: int32_t [[Q:.*]] = ([[A]] / [[B]]) - (([[A]] % [[B]] != 0) && (([[A]] < 0) != ([[B]] < 0))); + %q = arith.floordivsi %a, %b : i32 + // CHECK: int32_t [[M:.*]] = [[Q]] * [[B]]; + %m = arith.muli %q, %b : i32 + return %m : i32 +} + +// --- Ct-pt / ct-const operations --- + +// CHECK: Ct test_add_plain( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], Pt& [[PT:.*]]) +func.func @test_add_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[PT]].SetScale([[CT]].GetScale()); + // CHECK-NEXT: [[CTX]]->Add([[RES]], [[CT]], [[PT]]); + %result = cheddar.add_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_sub_plain( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], Pt& [[PT:.*]]) +func.func @test_sub_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: if (std::getenv("HEIR_CHEDDAR_DEBUG_SCALES")) { + // CHECK: double lhs_scale = [[CT]].GetScale(); + // CHECK-NEXT: double rhs_scale = [[PT]].GetScale(); + // CHECK: [[CTX]]->Sub([[RES]], [[CT]], [[PT]]); + %result = cheddar.sub_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_mult_plain( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], Pt& [[PT:.*]]) +func.func @test_mult_plain( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Mult([[RES]], [[CT]], [[PT]]); + %result = cheddar.mult_plain %ctx, %ct, %pt : (!cheddar.context, !cheddar.ciphertext, !cheddar.plaintext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_add_const( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], const Const& [[C:.*]]) +func.func @test_add_const( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Add([[RES]], [[CT]], [[C]]); + %result = cheddar.add_const %ctx, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_mult_const( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], const Const& [[C:.*]]) +func.func @test_mult_const( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Mult([[RES]], [[CT]], [[C]]); + %result = cheddar.mult_const %ctx, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Unary operations --- + +// CHECK: Ct test_neg( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]]) +func.func @test_neg( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Neg([[RES]], [[CT]]); + %result = cheddar.neg %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_rescale( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]]) +func.func @test_rescale( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Rescale([[RES]], [[CT]]); + %result = cheddar.rescale %ctx, %ct : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_level_down( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]]) +func.func @test_level_down( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->LevelDown([[RES]], [[CT]], 3); + %result = cheddar.level_down %ctx, %ct {targetLevel = 3 : i64} : (!cheddar.context, !cheddar.ciphertext) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Key-switching operations --- + +// CHECK: Ct test_relinearize( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], const Evk& [[KEY:.*]]) +func.func @test_relinearize( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Relinearize([[RES]], [[CT]], [[KEY]]); + %result = cheddar.relinearize %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_relinearize_rescale( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], const Evk& [[KEY:.*]]) +func.func @test_relinearize_rescale( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->RelinearizeRescale([[RES]], [[CT]], [[KEY]]); + %result = cheddar.relinearize_rescale %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Fused compound operations --- + +// CHECK: Ct test_hmult( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]], const Evk& [[KEY:.*]]) +func.func @test_hmult( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->HMult([[RES]], [[CT0]], [[CT1]], [[KEY]], true); + %result = cheddar.hmult %ctx, %ct0, %ct1, %key {rescale = true} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_hmult_no_rescale( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]], const Evk& [[KEY:.*]]) +func.func @test_hmult_no_rescale( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->HMult([[RES]], [[CT0]], [[CT1]], [[KEY]], false); + %result = cheddar.hmult %ctx, %ct0, %ct1, %key {rescale = false} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_hrot( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], const Evk& [[KEY:.*]]) +func.func @test_hrot( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->HRot([[RES]], [[CT]], [[KEY]], 5); + %result = cheddar.hrot %ctx, %ct, %key {static_shift = 5 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_hrot_add( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]], const Evk& [[KEY:.*]]) +func.func @test_hrot_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->HRotAdd([[RES]], [[CT0]], [[CT1]], [[KEY]], 3); + %result = cheddar.hrot_add %ctx, %ct0, %ct1, %key {distance = 3 : i64} : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_hconj( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT:.*]], const Evk& [[KEY:.*]]) +func.func @test_hconj( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->HConj([[RES]], [[CT]], [[KEY]]); + %result = cheddar.hconj %ctx, %ct, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_hconj_add( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[CT0:.*]], const Ct& [[CT1:.*]], const Evk& [[KEY:.*]]) +func.func @test_hconj_add( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext, + %key: !cheddar.eval_key) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->HConjAdd([[RES]], [[CT0]], [[CT1]], [[KEY]]); + %result = cheddar.hconj_add %ctx, %ct0, %ct1, %key : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.eval_key) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// CHECK: Ct test_mad_unsafe( +// CHECK-SAME: CtxPtr [[CTX:.*]], const Ct& [[ACC:.*]], const Ct& [[CT:.*]], const Const& [[C:.*]]) +func.func @test_mad_unsafe( + %ctx: !cheddar.context, + %acc: !cheddar.ciphertext, + %ct: !cheddar.ciphertext, + %c: !cheddar.constant) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[CTX]]->Copy([[RES]], [[ACC]]); + // CHECK-NEXT: [[CTX]]->MadUnsafe([[RES]], [[CT]], [[C]]); + %result = cheddar.mad_unsafe %ctx, %acc, %ct, %c : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext, !cheddar.constant) -> !cheddar.ciphertext + return %result : !cheddar.ciphertext +} + +// --- Setup operations --- + +// CHECK: void test_setup( +// CHECK-SAME: const Param& [[PARAMS:.*]]) +func.func @test_setup(%params: !cheddar.parameter) { + // CHECK: auto [[CTX:.*]] = Context::Create([[PARAMS]]); + %ctx = cheddar.create_context %params : (!cheddar.parameter) -> !cheddar.context + // CHECK: UI [[UI:.*]]([[CTX]]); + %ui = cheddar.create_user_interface %ctx : (!cheddar.context) -> !cheddar.user_interface + // CHECK: auto& [[ENC:.*]] = [[CTX]]->encoder_; + %enc = cheddar.get_encoder %ctx : (!cheddar.context) -> !cheddar.encoder + // CHECK: const auto& [[EVK:.*]] = [[UI]].GetEvkMap(); + %evk = cheddar.get_evk_map %ui : (!cheddar.user_interface) -> !cheddar.evk_map + // CHECK: const auto& [[MKEY:.*]] = [[UI]].GetMultiplicationKey(); + %mkey = cheddar.get_mult_key %ui : (!cheddar.user_interface) -> !cheddar.eval_key + // CHECK: const auto& [[RKEY:.*]] = [[UI]].GetRotationKey(7); + %rkey = cheddar.get_rot_key %ui {distance = 7 : i64} : (!cheddar.user_interface) -> !cheddar.eval_key + // CHECK: const auto& [[CKEY:.*]] = [[UI]].GetConjugationKey(); + %ckey = cheddar.get_conj_key %ui : (!cheddar.user_interface) -> !cheddar.eval_key + // CHECK: [[UI]].PrepareRotationKey(3, 10); + cheddar.prepare_rot_key %ui {distance = 3 : i64, maxLevel = 10 : i64} : (!cheddar.user_interface) -> () + return +} + +// --- Encrypt / Decrypt --- + +// CHECK: Ct test_encrypt( +// CHECK-SAME: UI& [[UI:.*]], Pt& [[PT:.*]]) +func.func @test_encrypt( + %ui: !cheddar.user_interface, + %pt: !cheddar.plaintext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK-NEXT: [[UI]].Encrypt([[RES]], [[PT]]); + %ct = cheddar.encrypt %ui, %pt : (!cheddar.user_interface, !cheddar.plaintext) -> !cheddar.ciphertext + return %ct : !cheddar.ciphertext +} + +// CHECK: Pt test_decrypt( +// CHECK-SAME: UI& [[UI:.*]], const Ct& [[CT:.*]]) +func.func @test_decrypt( + %ui: !cheddar.user_interface, + %ct: !cheddar.ciphertext) -> !cheddar.plaintext { + // CHECK: Pt [[RES:.*]]; + // CHECK-NEXT: [[UI]].Decrypt([[RES]], [[CT]]); + %pt = cheddar.decrypt %ui, %ct : (!cheddar.user_interface, !cheddar.ciphertext) -> !cheddar.plaintext + return %pt : !cheddar.plaintext +} diff --git a/tests/Emitter/Cheddar/emit_cheddar_control_flow.mlir b/tests/Emitter/Cheddar/emit_cheddar_control_flow.mlir new file mode 100644 index 0000000000..5da1031ac4 --- /dev/null +++ b/tests/Emitter/Cheddar/emit_cheddar_control_flow.mlir @@ -0,0 +1,79 @@ +// RUN: heir-translate --emit-cheddar %s | FileCheck %s + +// Test scf.for emission +// CHECK: Ct test_for_loop +func.func @test_for_loop( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + // CHECK: for (int64_t {{.*}} = 0; {{.*}} < 4; {{.*}} += 1) { + %result = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %ct0) -> !cheddar.ciphertext { + // CHECK: {{.*}}->Add( + %sum = cheddar.add %ctx, %acc, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + scf.yield %sum : !cheddar.ciphertext + } + // CHECK: } + return %result : !cheddar.ciphertext +} + +// Test scf.if emission +// CHECK: Ct test_if +func.func @test_if( + %ctx: !cheddar.context, + %cond: i1, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> !cheddar.ciphertext { + // CHECK: Ct [[RES:.*]]; + // CHECK: if ({{.*}}) { + %result = scf.if %cond -> !cheddar.ciphertext { + // CHECK: {{.*}}->Add( + %sum = cheddar.add %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + scf.yield %sum : !cheddar.ciphertext + } else { + // CHECK: } else { + // CHECK: {{.*}}->Sub( + %diff = cheddar.sub %ctx, %ct0, %ct1 : (!cheddar.context, !cheddar.ciphertext, !cheddar.ciphertext) -> !cheddar.ciphertext + scf.yield %diff : !cheddar.ciphertext + } + // CHECK: } + return %result : !cheddar.ciphertext +} + +// Test that numeric tensor updates carried through SCF loops reuse the moved +// vector instead of deep-copying it on every tensor.insert. +// CHECK: std::vector test_tensor_insert_move +func.func @test_tensor_insert_move() -> tensor<4xf32> { + %cst = arith.constant dense<0.0> : tensor<4xf32> + %c1 = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c1_idx = arith.constant 1 : index + %result = scf.for %i = %c0 to %c4 step %c1_idx iter_args(%acc = %cst) -> tensor<4xf32> { + // CHECK: auto [[INS:.*]] = std::move([[ACC:.*]]); + // CHECK-NEXT: [[INS]][{{.*}}] = {{.*}}; + %inserted = tensor.insert %c1 into %acc[%i] : tensor<4xf32> + scf.yield %inserted : tensor<4xf32> + } + return %result : tensor<4xf32> +} + +// Test that numeric tensor.insert_slice also reuses the moved vector when the +// destination is a loop-carried local. +// CHECK: std::vector test_tensor_insert_slice_move +func.func @test_tensor_insert_slice_move() -> tensor<4xf32> { + %dest = arith.constant dense<0.0> : tensor<4xf32> + %src = arith.constant dense<1.0> : tensor<2xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %result = scf.for %i = %c0 to %c1 step %c1 iter_args(%acc = %dest) -> tensor<4xf32> { + // CHECK: auto [[INS_SLICE:.*]] = std::move([[ACC_SLICE:.*]]); + // CHECK-NEXT: std::copy({{.*}}.begin(), {{.*}}.end(), [[INS_SLICE]].begin() + 1); + %inserted = tensor.insert_slice %src into %acc[1] [2] [1] : tensor<2xf32> into tensor<4xf32> + scf.yield %inserted : tensor<4xf32> + } + return %result : tensor<4xf32> +} diff --git a/tests/Emitter/Cheddar/emit_cheddar_tensor.mlir b/tests/Emitter/Cheddar/emit_cheddar_tensor.mlir new file mode 100644 index 0000000000..15fa1268ea --- /dev/null +++ b/tests/Emitter/Cheddar/emit_cheddar_tensor.mlir @@ -0,0 +1,117 @@ +// RUN: heir-translate --emit-cheddar %s | FileCheck %s + +// Test tensor.empty +// CHECK: void test_tensor_empty +func.func @test_tensor_empty(%ctx: !cheddar.context) { + // CHECK: std::vector [[VEC:.*]]; + // CHECK-NEXT: [[VEC]].resize(4); + %empty = tensor.empty() : tensor<4x!cheddar.ciphertext> + return +} + +// Test dense splat constants lower to sized vector constructors instead of +// enormous initializer lists. +// CHECK: {{.*}} test_dense_splats +func.func @test_dense_splats() { + // CHECK: std::vector [[FLOATS:.*]](8, 0{{(\.0+)?}}); + %floats = arith.constant dense<0.0> : tensor<8xf32> + // CHECK: std::vector [[INTS:.*]](6, 7); + %ints = arith.constant dense<7> : tensor<6xindex> + return +} + +// Test tensor.from_elements with ciphertexts (move-only) +// CHECK: {{.*}} test_from_elements +func.func @test_from_elements( + %ctx: !cheddar.context, + %ct0: !cheddar.ciphertext, + %ct1: !cheddar.ciphertext) -> tensor<2x!cheddar.ciphertext> { + // CHECK: std::vector [[VEC:.*]]; + // CHECK-NEXT: [[VEC]].reserve(2); + // CHECK: Copy( + // CHECK: [[VEC]].emplace_back(std::move( + // CHECK: Copy( + // CHECK: [[VEC]].emplace_back(std::move( + %t = tensor.from_elements %ct0, %ct1 : tensor<2x!cheddar.ciphertext> + return %t : tensor<2x!cheddar.ciphertext> +} + +// Test tensor.extract with ciphertexts (reference) +// CHECK: {{.*}} test_extract +func.func @test_extract( + %ctx: !cheddar.context, + %t: tensor<4x!cheddar.ciphertext>) -> !cheddar.ciphertext { + %c1 = arith.constant 1 : index + // CHECK: auto& [[RES:.*]] = {{.*}}[{{.*}}]; + %elem = tensor.extract %t[%c1] : tensor<4x!cheddar.ciphertext> + return %elem : !cheddar.ciphertext +} + +// Test tensor.insert with ciphertexts (in-place move) +// CHECK: {{.*}} test_insert +func.func @test_insert( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext, + %t: tensor<4x!cheddar.ciphertext>) -> tensor<4x!cheddar.ciphertext> { + %c2 = arith.constant 2 : index + // CHECK: Copy( + // CHECK: {{.*}}[{{.*}}] = std::move( + %result = tensor.insert %ct into %t[%c2] : tensor<4x!cheddar.ciphertext> + return %result : tensor<4x!cheddar.ciphertext> +} + +// Test tensor.insert into tensor.empty with ciphertexts. This should not deep +// copy the uninitialized empty tensor elements. +// CHECK: {{.*}} test_insert_into_empty +func.func @test_insert_into_empty( + %ctx: !cheddar.context, + %ct: !cheddar.ciphertext) -> tensor<1x!cheddar.ciphertext> { + %empty = tensor.empty() : tensor<1x!cheddar.ciphertext> + %c0 = arith.constant 0 : index + // CHECK: std::vector [[EMPTY:.*]]; + // CHECK-NEXT: [[EMPTY]].resize(1); + // CHECK: std::vector [[RES:.*]]; + // CHECK-NEXT: [[RES]].resize(1); + // CHECK-NOT: [[EMPTY]][i] + // CHECK: Copy( + // CHECK: [[RES]][{{.*}}] = std::move( + %result = tensor.insert %ct into %empty[%c0] : tensor<1x!cheddar.ciphertext> + return %result : tensor<1x!cheddar.ciphertext> +} + +// Test tensor.extract_slice +// CHECK: {{.*}} test_extract_slice +func.func @test_extract_slice( + %ctx: !cheddar.context, + %t: tensor<8x!cheddar.ciphertext>) -> tensor<4x!cheddar.ciphertext> { + // CHECK: std::vector {{.*}}({{.*}}.begin() + 2, {{.*}}.begin() + 2 + 4); + %result = tensor.extract_slice %t[2] [4] [1] : tensor<8x!cheddar.ciphertext> to tensor<4x!cheddar.ciphertext> + return %result : tensor<4x!cheddar.ciphertext> +} + +// Numeric tensor.insert into tensor.empty should allocate a fresh zeroed result +// instead of copying the empty destination vector. +// CHECK: {{.*}} test_insert_numeric_into_empty +func.func @test_insert_numeric_into_empty(%x: f32) -> tensor<4xf32> { + %empty = tensor.empty() : tensor<4xf32> + %c1 = arith.constant 1 : index + // CHECK: std::vector [[EMPTY:.*]](4); + // CHECK: std::vector [[RES:.*]](4); + // CHECK-NOT: [[RES]] = [[EMPTY]] + // CHECK: [[RES]][{{.*}}] = {{.*}}; + %result = tensor.insert %x into %empty[%c1] : tensor<4xf32> + return %result : tensor<4xf32> +} + +// Numeric tensor.insert_slice into tensor.empty should also allocate a fresh +// zeroed result instead of copying the empty destination. +// CHECK: {{.*}} test_insert_slice_numeric_into_empty +func.func @test_insert_slice_numeric_into_empty(%src: tensor<2xf32>) -> tensor<4xf32> { + %empty = tensor.empty() : tensor<4xf32> + // CHECK: std::vector [[EMPTY:.*]](4); + // CHECK: std::vector [[RES:.*]](4); + // CHECK-NOT: [[RES]] = [[EMPTY]] + // CHECK: std::copy({{.*}}.begin(), {{.*}}.end(), [[RES]].begin() + 1); + %result = tensor.insert_slice %src into %empty[1] [2] [1] : tensor<2xf32> into tensor<4xf32> + return %result : tensor<4xf32> +} diff --git a/tests/Examples/common/mult_dep_8f.mlir b/tests/Examples/common/mult_dep_8f.mlir new file mode 100644 index 0000000000..c302ee4bd0 --- /dev/null +++ b/tests/Examples/common/mult_dep_8f.mlir @@ -0,0 +1,12 @@ +func.func @mult_dep( + %arg0: f32 {secret.secret} + ) -> f32 { + %0 = arith.mulf %arg0, %arg0 : f32 + %1 = arith.mulf %0, %arg0 : f32 + %2 = arith.mulf %1, %arg0 : f32 + %3 = arith.mulf %2, %arg0 : f32 + %4 = arith.mulf %3, %arg0 : f32 + %5 = arith.mulf %4, %arg0 : f32 + %6 = arith.mulf %5, %arg0 : f32 + return %6 : f32 +} diff --git a/tests/Examples/common/mult_indep_8f.mlir b/tests/Examples/common/mult_indep_8f.mlir new file mode 100644 index 0000000000..46f44fd8dc --- /dev/null +++ b/tests/Examples/common/mult_indep_8f.mlir @@ -0,0 +1,19 @@ +func.func @mult_indep( + %arg0: f32 {secret.secret}, + %arg1: f32 {secret.secret}, + %arg2: f32 {secret.secret}, + %arg3: f32 {secret.secret}, + %arg4: f32 {secret.secret}, + %arg5: f32 {secret.secret}, + %arg6: f32 {secret.secret}, + %arg7: f32 {secret.secret} + ) -> f32 { + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.mulf %0, %arg2 : f32 + %2 = arith.mulf %1, %arg3 : f32 + %3 = arith.mulf %2, %arg4 : f32 + %4 = arith.mulf %3, %arg5 : f32 + %5 = arith.mulf %4, %arg6 : f32 + %6 = arith.mulf %5, %arg7 : f32 + return %6 : f32 +} diff --git a/tests/Examples/common/simple_sumf.mlir b/tests/Examples/common/simple_sumf.mlir new file mode 100644 index 0000000000..f588c53c79 --- /dev/null +++ b/tests/Examples/common/simple_sumf.mlir @@ -0,0 +1,10 @@ +func.func @simple_sum(%arg0: tensor<32xf32> {secret.secret}) -> f32 { + %c0 = arith.constant 0 : index + %c0_f32 = arith.constant 0.0 : f32 + %0 = affine.for %i = 0 to 32 iter_args(%sum_iter = %c0_f32) -> f32 { + %1 = tensor.extract %arg0[%i] : tensor<32xf32> + %2 = arith.addf %1, %sum_iter : f32 + affine.yield %2 : f32 + } + return %0 : f32 +} diff --git a/tests/Transforms/annotate_module/cheddar.mlir b/tests/Transforms/annotate_module/cheddar.mlir new file mode 100644 index 0000000000..12dba0109d --- /dev/null +++ b/tests/Transforms/annotate_module/cheddar.mlir @@ -0,0 +1,6 @@ +// RUN: heir-opt --annotate-module="backend=cheddar scheme=ckks" %s | FileCheck %s + +// CHECK: module attributes {backend.cheddar, scheme.ckks} +module { + +} diff --git a/tools/BUILD b/tools/BUILD index fadc25dd20..ede783a0fc 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -58,11 +58,15 @@ cc_binary( "@heir//lib/Dialect/CGGI/Transforms", "@heir//lib/Dialect/CKKS/IR:Dialect", "@heir//lib/Dialect/CKKS/Transforms", + "@heir//lib/Dialect/Cheddar/IR:Dialect", + "@heir//lib/Dialect/Cheddar/Transforms:ConfigureCryptoContext", + "@heir//lib/Dialect/Cheddar/Transforms:FuseOps", "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/Debug/Transforms", "@heir//lib/Dialect/Jaxite/IR:Dialect", "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Dialect/LWE/Conversions/LWEToCheddar", "@heir//lib/Dialect/LWE/Conversions/LWEToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", @@ -228,6 +232,7 @@ cc_binary( srcs = ["heir-translate.cpp"], includes = ["include"], deps = [ + "@heir//lib/Target/Cheddar:CheddarEmitter", "@heir//lib/Target/FunctionInfo:FunctionInfoEmitter", "@heir//lib/Target/Jaxite:JaxiteEmitter", "@heir//lib/Target/JaxiteWord:JaxiteWordEmitter", @@ -253,6 +258,7 @@ cc_binary( "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/CGGI/IR:Dialect", "@heir//lib/Dialect/CKKS/IR:Dialect", + "@heir//lib/Dialect/Cheddar/IR:Dialect", "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/Jaxite/IR:Dialect", diff --git a/tools/heir-cheddar.bzl b/tools/heir-cheddar.bzl new file mode 100644 index 0000000000..da2d1a8c2a --- /dev/null +++ b/tools/heir-cheddar.bzl @@ -0,0 +1,53 @@ +"""Macros for building CHEDDAR-based FHE targets from HEIR-generated code.""" + +load("@heir//bazel/cheddar:config.bzl", "requires_cheddar") +load("@heir//tools:heir-translate.bzl", "heir_translate") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +def cheddar_lib( + name, + mlir_src, + heir_translate_flags = ["--emit-cheddar"], + extra_deps = [], + **kwargs): + """Generate a cc_library from HEIR MLIR targeting the CHEDDAR API. + + Args: + name: Name of the generated cc_library target. + mlir_src: The input .mlir file (already lowered to cheddar dialect). + heir_translate_flags: Flags to pass to heir-translate. + extra_deps: Additional cc_library deps. + **kwargs: Additional args forwarded to cc_library. + """ + + # Generate the .cpp file + cpp_name = name + "_cpp" + heir_translate( + name = cpp_name, + src = mlir_src, + pass_flags = heir_translate_flags, + generated_filename = name + ".cpp", + ) + + # Generate the .h file + header_name = name + "_header" + heir_translate( + name = header_name, + src = mlir_src, + pass_flags = ["--emit-cheddar-header"], + generated_filename = name + ".h", + ) + + cc_library( + name = name, + srcs = [":" + cpp_name], + hdrs = [":" + header_name], + deps = [ + "@cheddar//:cheddar", + "@cuda//:cuda_headers", + "@cuda//:cuda_runtime", + "@cuda//:thrust", + ] + extra_deps, + target_compatible_with = requires_cheddar(), + **kwargs + ) diff --git a/tools/heir-lsp.cpp b/tools/heir-lsp.cpp index bffc72d65d..d1c1499451 100644 --- a/tools/heir-lsp.cpp +++ b/tools/heir-lsp.cpp @@ -1,6 +1,7 @@ #include "lib/Dialect/BGV/IR/BGVDialect.h" #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CKKS/IR/CKKSDialect.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Debug/IR/DebugDialect.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" @@ -46,6 +47,7 @@ int main(int argc, char** argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 02879372a4..63cd210d8e 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -16,12 +16,16 @@ #include "lib/Dialect/CGGI/Transforms/Passes.h" #include "lib/Dialect/CKKS/IR/CKKSDialect.h" #include "lib/Dialect/CKKS/Transforms/Passes.h" +#include "lib/Dialect/Cheddar/IR/CheddarDialect.h" +#include "lib/Dialect/Cheddar/Transforms/ConfigureCryptoContext.h" +#include "lib/Dialect/Cheddar/Transforms/FuseOps.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Debug/IR/DebugDialect.h" #include "lib/Dialect/Debug/Transforms/Passes.h" #include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" #include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/LWE/Conversions/LWEToCheddar/LWEToCheddar.h" #include "lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" #include "lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.h" @@ -191,6 +195,7 @@ int main(int argc, char** argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -293,6 +298,8 @@ int main(int argc, char** argv) { registerEmitCInterfacePass(); cggi::registerCGGIPasses(); debug::registerDebugPasses(); + cheddar::registerConfigureCryptoContextPasses(); + cheddar::registerCheddarFuseOpsPasses(); ckks::registerCKKSPasses(); lattigo::registerLattigoPasses(); lwe::registerLWEPasses(); @@ -387,6 +394,7 @@ int main(int argc, char** argv) { // Dialect conversion passes in HEIR bgv::registerBGVToLWEPasses(); + lwe::registerLWEToCheddarPasses(); lwe::registerLWEToLattigoPasses(); lwe::registerLWEToOpenfhePasses(); lwe::registerLWEToPolynomialPasses(); @@ -490,6 +498,11 @@ int main(int argc, char** argv) { "Convert code expressed at FHE scheme level to Lattigo Go code.", toLattigoPipelineBuilder()); + PassPipelineRegistration( + "scheme-to-cheddar", + "Convert code expressed at FHE scheme level to CHEDDAR C++ code.", + toCheddarPipelineBuilder()); + // TODO(#1645): Add backend options for tfhe-rs, fpt, jaxite. PassPipelineRegistration<>( "scheme-to-tfhe-rs", diff --git a/tools/heir-translate.cpp b/tools/heir-translate.cpp index c4e73daae8..5fcfcc423f 100644 --- a/tools/heir-translate.cpp +++ b/tools/heir-translate.cpp @@ -1,3 +1,4 @@ +#include "lib/Target/Cheddar/CheddarEmitter.h" #include "lib/Target/FunctionInfo/FunctionInfoEmitter.h" #include "lib/Target/Jaxite/JaxiteEmitter.h" #include "lib/Target/JaxiteWord/JaxiteWordEmitter.h" @@ -40,6 +41,11 @@ int main(int argc, char** argv) { mlir::heir::openfhe::registerToOpenFhePkeDebugHeaderTranslation(); mlir::heir::openfhe::registerToOpenFhePkeDebugTranslation(); + // CHEDDAR + mlir::heir::cheddar::registerCheddarTranslateOptions(); + mlir::heir::cheddar::registerToCheddarTranslation(); + mlir::heir::cheddar::registerToCheddarHeaderTranslation(); + // Lattigo mlir::heir::lattigo::registerToLattigoTranslation(); mlir::heir::lattigo::registerToLattigoPreprocessingTranslation();