Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/Analysis/LevelAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
"@heir//lib/Dialect:ModuleAttributes",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Target/CompilationTarget",
"@heir//lib/Utils",
"@heir//lib/Utils:AttributeUtils",
"@llvm-project//llvm:Support",
Expand Down
9 changes: 7 additions & 2 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/ModuleAttributes.h"
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Target/CompilationTarget/CompilationTarget.h"
#include "lib/Utils/AttributeUtils.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
Expand Down Expand Up @@ -96,12 +97,16 @@ LevelState transferForward(mgmt::LevelReduceMinOp op,

LevelState transferForward(mgmt::BootstrapOp op,
ArrayRef<const LevelLattice*> operands) {
auto module = op->getParentOfType<ModuleOp>();
const CompilationTarget* target = getTargetConfig(module);
int levelsConsumed = target ? target->bootstrapLevelsConsumed : 0;

LevelState result = std::visit(
Overloaded{
[](MaxLevel) -> LevelState { return LevelState(0); },
[=](MaxLevel) -> LevelState { return LevelState(levelsConsumed); },
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
[](int val) -> LevelState { return LevelState(0); },
[=](int val) -> LevelState { return LevelState(levelsConsumed); },
},
operands[0]->getValue().get());
LLVM_DEBUG(debugLog("bootstrap", operands, result));
Expand Down
27 changes: 27 additions & 0 deletions lib/Tablegen/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "CompilationTargetEmitter",
srcs = ["CompilationTargetEmitter.cpp"],
hdrs = ["CompilationTargetEmitter.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
],
)

cc_binary(
name = "heir-tblgen",
srcs = ["TablegenMain.cpp"],
deps = [
":CompilationTargetEmitter",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
],
)
28 changes: 28 additions & 0 deletions lib/Tablegen/CompilationTargetEmitter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "lib/Tablegen/CompilationTargetEmitter.h"

#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "llvm/include/llvm/TableGen/Record.h" // from @llvm-project
#include "llvm/include/llvm/TableGen/TableGenBackend.h" // from @llvm-project

namespace mlir {
namespace heir {

bool emitCompilationTargetRegistration(const llvm::RecordKeeper& records,
llvm::raw_ostream& os) {
auto targets = records.getAllDerivedDefinitions("CompilationTarget");

os << "CompilationTargetRegistry::CompilationTargetRegistry() {\n";
for (auto* target : targets) {
auto backendName = target->getValueAsString("backendName");
auto bootstrapLevelsConsumed =
target->getValueAsInt("bootstrapLevelsConsumed");

os << " targets[\"" << backendName << "\"] = CompilationTarget{\""
<< backendName << "\", " << (int)bootstrapLevelsConsumed << "};\n";
}
os << "}\n";
return false;
}

} // namespace heir
} // namespace mlir
15 changes: 15 additions & 0 deletions lib/Tablegen/CompilationTargetEmitter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIB_TABLEGEN_COMPILATIONTARGETEMITTER_H_
#define LIB_TABLEGEN_COMPILATIONTARGETEMITTER_H_

#include "llvm/include/llvm/TableGen/Record.h" // from @llvm-project

namespace mlir {
namespace heir {

bool emitCompilationTargetRegistration(const llvm::RecordKeeper& records,
llvm::raw_ostream& os);

} // namespace heir
} // namespace mlir

#endif // LIB_TABLEGEN_COMPILATIONTARGETEMITTER_H_
35 changes: 35 additions & 0 deletions lib/Tablegen/TablegenMain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "lib/Tablegen/CompilationTargetEmitter.h"
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
#include "llvm/include/llvm/Support/InitLLVM.h" // from @llvm-project
#include "llvm/include/llvm/TableGen/Main.h" // from @llvm-project
#include "llvm/include/llvm/TableGen/Record.h" // from @llvm-project

using namespace mlir;
using namespace heir;

enum ActionType {
None,
GenCompilationTargetRegistration,
};

static llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(GenCompilationTargetRegistration,
"gen-compilation-target-registration",
"Generate compilation target registration")));

bool heirTableGenMain(llvm::raw_ostream& os,
const llvm::RecordKeeper& records) {
switch (action) {
case GenCompilationTargetRegistration:
return emitCompilationTargetRegistration(records, os);
default:
return false;
}
}

int main(int argc, char** argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
return llvm::TableGenMain(argv[0], &heirTableGenMain);
}
45 changes: 45 additions & 0 deletions lib/Target/CompilationTarget/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "CompilationTarget",
srcs = ["CompilationTarget.cpp"],
hdrs = ["CompilationTarget.h"],
deps = [
":compilation_target_inc_gen",
"@heir//lib/Dialect:ModuleAttributes",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)

td_library(
name = "td_files",
srcs = ["HEIRTarget.td"],
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
],
)

gentbl_cc_library(
name = "compilation_target_inc_gen",
tbl_outs = [
(
["-gen-compilation-target-registration"],
"CompilationTarget.cpp.inc",
),
],
tblgen = "@heir//lib/Tablegen:heir-tblgen",
td_file = "HEIRTarget.td",
deps = [
":td_files",
],
)

exports_files(["HEIRTarget.td"])
38 changes: 38 additions & 0 deletions lib/Target/CompilationTarget/CompilationTarget.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "lib/Target/CompilationTarget/CompilationTarget.h"

#include "lib/Dialect/ModuleAttributes.h"
#include "llvm/include/llvm/ADT/StringMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project

namespace mlir {
namespace heir {

#include "lib/Target/CompilationTarget/CompilationTarget.cpp.inc"

CompilationTargetRegistry& CompilationTargetRegistry::getInstance() {
static CompilationTargetRegistry instance;
return instance;
}

const CompilationTarget* CompilationTargetRegistry::get(llvm::StringRef name) {
auto& instance = getInstance();
auto it = instance.targets.find(name);
if (it == instance.targets.end()) {
return nullptr;
}
return &it->second;
}

const CompilationTarget* getTargetConfig(ModuleOp module) {
for (auto attr : module->getAttrs()) {
llvm::StringRef name = attr.getName().strref();
if (name.consume_front("backend.")) {
return CompilationTargetRegistry::get(name);
}
}
return nullptr;
}

} // namespace heir
} // namespace mlir
34 changes: 34 additions & 0 deletions lib/Target/CompilationTarget/CompilationTarget.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef LIB_TARGET_COMPILATIONTARGET_COMPILATIONTARGET_H_
#define LIB_TARGET_COMPILATIONTARGET_COMPILATIONTARGET_H_

#include <string>

#include "llvm/include/llvm/ADT/StringMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project

namespace mlir {
namespace heir {

struct CompilationTarget {
std::string backendName;
int bootstrapLevelsConsumed;
};

class CompilationTargetRegistry {
public:
static const CompilationTarget* get(llvm::StringRef name);

private:
CompilationTargetRegistry();
static CompilationTargetRegistry& getInstance();

llvm::StringMap<CompilationTarget> targets;
};

const CompilationTarget* getTargetConfig(ModuleOp module);

} // namespace heir
} // namespace mlir

#endif // LIB_TARGET_COMPILATIONTARGET_COMPILATIONTARGET_H_
12 changes: 12 additions & 0 deletions lib/Target/CompilationTarget/HEIRTarget.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class CompilationTarget<string name> {
string backendName = name;
int bootstrapLevelsConsumed = 0;
}

def OpenFHE : CompilationTarget<"openfhe"> {
let bootstrapLevelsConsumed = 3;
}

def Lattigo : CompilationTarget<"lattigo"> {
let bootstrapLevelsConsumed = 1;
}
17 changes: 13 additions & 4 deletions lib/Transforms/AnnotateModule/AnnotateModule.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "lib/Transforms/AnnotateModule/AnnotateModule.h"

#include "lib/Dialect/ModuleAttributes.h"
#include "lib/Target/CompilationTarget/CompilationTarget.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

Expand All @@ -26,10 +27,18 @@ struct AnnotateModule : impl::AnnotateModuleBase<AnnotateModule> {
moduleSetCGGI(module);
}

if (backend == "openfhe") {
moduleSetOpenfhe(module);
} else if (backend == "lattigo") {
moduleSetLattigo(module);
if (!backend.empty()) {
if (!CompilationTargetRegistry::get(backend)) {
module.emitError() << "Unknown backend: " << backend;
signalPassFailure();
return;
}

if (backend == "openfhe") {
moduleSetOpenfhe(module);
} else if (backend == "lattigo") {
moduleSetLattigo(module);
}
}
}
};
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/AnnotateModule/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cc_library(
deps = [
":pass_inc_gen",
"@heir//lib/Dialect:ModuleAttributes",
"@heir//lib/Target/CompilationTarget",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
Expand Down
19 changes: 19 additions & 0 deletions tests/Dialect/Mgmt/Transforms/bootstrap_levels.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: heir-opt --annotate-module="backend=openfhe" --annotate-mgmt %s | FileCheck %s --check-prefix=CHECK-OPENFHE
// RUN: heir-opt --annotate-module="backend=lattigo" --annotate-mgmt %s | FileCheck %s --check-prefix=CHECK-LATTIGO

func.func @main(%arg0: !secret.secret<tensor<8xi8>>) -> !secret.secret<tensor<8xi8>> {
%b = secret.generic(%arg0: !secret.secret<tensor<8xi8>>) {
^body(%clear_a: tensor<8xi8>):
%c = mgmt.bootstrap %clear_a : tensor<8xi8>
secret.yield %c : tensor<8xi8>
} -> !secret.secret<tensor<8xi8>>
func.return %b : !secret.secret<tensor<8xi8>>
}

// CHECK-OPENFHE: func.func @main(%arg0: !secret.secret<tensor<8xi8>> {mgmt.mgmt = #mgmt.mgmt<level = 3>})
// CHECK-OPENFHE: mgmt.bootstrap
// CHECK-OPENFHE-SAME: {mgmt.mgmt = #mgmt.mgmt<level = 0>{{.*}}}

// CHECK-LATTIGO: func.func @main(%arg0: !secret.secret<tensor<8xi8>> {mgmt.mgmt = #mgmt.mgmt<level = 1>})
// CHECK-LATTIGO: mgmt.bootstrap
// CHECK-LATTIGO-SAME: {mgmt.mgmt = #mgmt.mgmt<level = 0>{{.*}}}
5 changes: 5 additions & 0 deletions tests/Transforms/annotate_module/invalid_backend.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// RUN: heir-opt --annotate-module="backend=invalid_backend" --verify-diagnostics %s

// expected-error @+1 {{Unknown backend: invalid_backend}}
module {
}
Loading