Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/Dialect/LWE/Conversions/LWEToLattigo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ cc_library(
"@heir//lib/Dialect/CKKS/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Lattigo/IR:Dialect",
"@heir//lib/Dialect/Orion/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
Expand Down
181 changes: 174 additions & 7 deletions lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.h"

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

#include "lib/Dialect/BGV/IR/BGVDialect.h"
#include "lib/Dialect/BGV/IR/BGVOps.h"
#include "lib/Dialect/CKKS/IR/CKKSAttributes.h"
#include "lib/Dialect/CKKS/IR/CKKSDialect.h"
#include "lib/Dialect/CKKS/IR/CKKSOps.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
Expand All @@ -17,9 +17,12 @@
#include "lib/Dialect/Lattigo/IR/LattigoOps.h"
#include "lib/Dialect/Lattigo/IR/LattigoTypes.h"
#include "lib/Dialect/ModuleAttributes.h"
#include "lib/Dialect/Orion/IR/OrionDialect.h"
#include "lib/Dialect/Orion/IR/OrionOps.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand All @@ -34,6 +37,8 @@
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "lwe-to-lattigo"

namespace mlir::heir::lwe {

class ToLattigoTypeConverter : public TypeConverter {
Expand Down Expand Up @@ -65,8 +70,9 @@ FailureOr<Value> getContextualEvaluator(Operation* op) {
auto result = getContextualArgFromFunc<EvaluatorType>(op);
if (failed(result)) {
return op->emitOpError()
<< "Found RLWE op in a function without a public "
"key argument. Did the AddEvaluatorArg pattern fail to run?";
<< "Found RLWE op in a function without a needed evaluator "
"argument. Did the AddEvaluatorArg pattern fail to run "
"for the evaluator needed by this op?";
}
return result.value();
}
Expand All @@ -75,6 +81,29 @@ FailureOr<Value> getContextualEvaluator(Operation* op, Type type) {
return getContextualArgFromFunc(op, type);
}

// Find the unique operation in the current func whose result has type Ty or
// return Failure.
template <typename Ty>
FailureOr<TypedValue<Ty>> findUniqueOpResult(Operation* op) {
TypedValue<Ty> foundValue;
bool found = false;
func::FuncOp funcOp = op->getParentOfType<func::FuncOp>();
funcOp->walk([&](Operation* innerOp) {
for (auto result : innerOp->getResults()) {
if (mlir::isa<Ty>(result.getType())) {
foundValue = cast<TypedValue<Ty>>(result);
found = true;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (found) {
return foundValue;
}
return failure();
}

// NOTE: we can not use containsDialect
// for FuncOp declaration, which does not have a body
template <typename... Dialects>
Expand Down Expand Up @@ -105,8 +134,13 @@ struct AddEvaluatorArg : public OpConversionPattern<func::FuncOp> {
SmallVector<Type, 4> selectedEvaluators;

for (const auto& evaluator : evaluators) {
LLVM_DEBUG(llvm::dbgs()
<< "Checking if evaluator should be added of type: "
<< evaluator.first << "\n");
auto predicate = evaluator.second;
if (predicate(op)) {
LLVM_DEBUG(llvm::dbgs()
<< "Adding evaluator of type: " << evaluator.first << "\n");
selectedEvaluators.push_back(evaluator.first);
}
}
Expand Down Expand Up @@ -352,7 +386,9 @@ struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
ConversionPatternRewriter& rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result)) return result;
if (failed(result))
return rewriter.notifyMatchFailure(op,
"Failed to get contextual evaluator");

Value evaluator = result.value();
rewriter.replaceOp(
Expand All @@ -364,6 +400,30 @@ struct ConvertRlweRotateOp : public OpConversionPattern<RlweRotateOp> {
}
};

template <typename EvaluatorType, typename RlweBootstrapOp,
typename LattigoBootstrapOp>
struct ConvertRlweBootstrapOp : public OpConversionPattern<RlweBootstrapOp> {
using OpConversionPattern<RlweBootstrapOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
RlweBootstrapOp op, typename RlweBootstrapOp::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
FailureOr<Value> result =
getContextualEvaluator<EvaluatorType>(op.getOperation());
if (failed(result))
return rewriter.notifyMatchFailure(op,
"Failed to get contextual evaluator");

Value evaluator = result.value();
rewriter.replaceOp(
op, LattigoBootstrapOp::create(
rewriter, op.getLoc(),
this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getInput()));
return success();
}
};

template <typename EvaluatorType, typename LevelReduceOp,
typename LattigoLevelReduceOp>
struct ConvertRlweLevelReduceOp : public OpConversionPattern<LevelReduceOp> {
Expand Down Expand Up @@ -471,6 +531,104 @@ struct ConvertLWEReinterpretApplicationData
}
};

// Orion conversions
struct ConvertOrionLinearTransformOp
: public OpConversionPattern<orion::LinearTransformOp> {
using OpConversionPattern<orion::LinearTransformOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
orion::LinearTransformOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
FailureOr<Value> evaluatorResult =
getContextualEvaluator<lattigo::CKKSEvaluatorType>(op.getOperation());
if (failed(evaluatorResult)) {
return op.emitOpError() << "CKKS evaluator not found in function context";
}
Value evaluator = evaluatorResult.value();

FailureOr<Value> encoderResult =
getContextualEvaluator<lattigo::CKKSEncoderType>(op.getOperation());
if (failed(encoderResult)) {
return op.emitOpError() << "CKKS encoder not found in function context";
}
Value encoder = encoderResult.value();

auto bsgsRatio = op.getBsgsRatioAttr();
int64_t logBsgsRatio =
static_cast<int64_t>(cast<FloatAttr>(bsgsRatio).getValueAsDouble());
auto logBsgsRatioAttr = rewriter.getI64IntegerAttr(logBsgsRatio);

rewriter.replaceOpWithNewOp<lattigo::CKKSLinearTransformOp>(
op, this->typeConverter->convertType(op.getResult().getType()),
evaluator, encoder, adaptor.getInput(), adaptor.getDiagonals(),
adaptor.getDiagonalIndices(), op.getOrionLevelAttr(), logBsgsRatioAttr);

return success();
}
};

struct ConvertOrionChebyshevOp
: public OpConversionPattern<orion::ChebyshevOp> {
using OpConversionPattern<orion::ChebyshevOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
orion::ChebyshevOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "Lowering Orion ChebyshevOp\n");
// Get or create the polynomial evaluator from the function context
FailureOr<Value> evaluatorResult =
findUniqueOpResult<lattigo::CKKSPolynomialEvaluatorType>(
op.getOperation());
Value polyEvaluator;
if (failed(evaluatorResult)) {
LLVM_DEBUG(llvm::dbgs() << "Creating new CKKS polynomial evaluator\n");
FailureOr<Value> evaluatorResult =
getContextualEvaluator<lattigo::CKKSEvaluatorType>(op.getOperation());
if (failed(evaluatorResult)) {
return rewriter.notifyMatchFailure(
op, "CKKS evaluator not found in function context");
}
Value evaluator = evaluatorResult.value();

FailureOr<Value> result2 =
getContextualEvaluator<lattigo::CKKSParameterType>(op.getOperation());
if (failed(result2))
return rewriter.notifyMatchFailure(
op, "Failed to get contextual CKKS parameters");
Value params = result2.value();

// Insert op at start of func so it's easy to find later
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(
&op->getParentOfType<func::FuncOp>().getBody().front());
auto evaluatorOp = lattigo::CKKSNewPolynomialEvaluatorOp::create(
rewriter, op.getLoc(),
lattigo::CKKSPolynomialEvaluatorType::get(rewriter.getContext()),
params, evaluator);
polyEvaluator = evaluatorOp.getResult();
}
} else {
polyEvaluator = evaluatorResult.value();
}

// Orion always uses the logDefaultScale for the target scale
ckks::SchemeParamAttr schemeParams =
cast<ckks::SchemeParamAttr>(getSchemeParamAttr(op));
IntegerAttr defaultScale = rewriter.getIntegerAttr(
rewriter.getI64Type(), 1L << schemeParams.getLogDefaultScale());
LLVM_DEBUG(llvm::dbgs()
<< "Using default scale: " << defaultScale.getInt() << "\n");

auto chebyshevOp = lattigo::CKKSChebyshevOp::create(
rewriter, op.getLoc(), adaptor.getInput().getType(), polyEvaluator,
adaptor.getInput(), adaptor.getCoefficients(), defaultScale);
rewriter.replaceOp(op, chebyshevOp.getResult());

return success();
}
};

} // namespace

// BGV
Expand Down Expand Up @@ -549,6 +707,10 @@ using ConvertCKKSRotateOp =
ConvertRlweRotateOp<lattigo::CKKSEvaluatorType, ckks::RotateOp,
lattigo::CKKSRotateNewOp>;

using ConvertCKKSBootstrapOp =
ConvertRlweBootstrapOp<lattigo::CKKSBootstrapperType, ckks::BootstrapOp,
lattigo::CKKSBootstrapOp>;

using ConvertCKKSEncryptOp =
ConvertRlweUnaryOp<lattigo::RLWEEncryptorType, lwe::RLWEEncryptOp,
lattigo::RLWEEncryptOp>;
Expand Down Expand Up @@ -612,7 +774,8 @@ struct LWEToLattigo : public impl::LWEToLattigoBase<LWEToLattigo> {

ConversionTarget target(*context);
target.addLegalDialect<lattigo::LattigoDialect>();
target.addIllegalDialect<bgv::BGVDialect, ckks::CKKSDialect>();
target.addIllegalDialect<bgv::BGVDialect, ckks::CKKSDialect,
orion::OrionDialect>();
target
.addIllegalOp<lwe::RLWEEncryptOp, lwe::RLWEDecryptOp, lwe::RLWEEncodeOp,
lwe::RLWEDecodeOp, lwe::RAddOp, lwe::RSubOp, lwe::RMulOp,
Expand Down Expand Up @@ -733,6 +896,9 @@ struct LWEToLattigo : public impl::LWEToLattigoBase<LWEToLattigo> {
{lattigo::CKKSEncoderType::get(context),
gateByCKKSModuleAttr(
containsArgumentOfDialect<lwe::LWEDialect, ckks::CKKSDialect>)},
{lattigo::CKKSBootstrapperType::get(context),
gateByCKKSModuleAttr(
containsArgumentOfDialect<lwe::LWEDialect, ckks::CKKSDialect>)},
{lattigo::RLWEEncryptorType::get(context, /*publicKey*/ true),
containsArgumentOfType<lwe::LWEPublicKeyType>},
// for LWESecretKey, if its uses are encrypt, then convert it to an
Expand Down Expand Up @@ -760,8 +926,9 @@ struct LWEToLattigo : public impl::LWEToLattigoBase<LWEToLattigo> {
ConvertCKKSAddPlainOp, ConvertCKKSSubPlainOp, ConvertCKKSMulPlainOp,
ConvertCKKSRelinOp, ConvertCKKSModulusSwitchOp, ConvertCKKSRotateOp,
ConvertCKKSEncryptOp, ConvertCKKSDecryptOp, ConvertCKKSEncodeOp,
ConvertCKKSDecodeOp, ConvertCKKSLevelReduceOp>(typeConverter,
context);
ConvertCKKSDecodeOp, ConvertCKKSLevelReduceOp, ConvertCKKSBootstrapOp,
ConvertOrionLinearTransformOp, ConvertOrionChebyshevOp>(typeConverter,
context);
}
// Misc
patterns.add<ConvertLWEReinterpretApplicationData>(typeConverter, context);
Expand Down
92 changes: 92 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoCKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,96 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInplaceOp<"rotate"> {
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_CKKSLinearTransformOp : Lattigo_CKKSOp<"linear_transform"> {
let summary = "Apply a linear transform on a lattigo CKKS ciphertext";
let description = [{
This operation applies a linear transform on a CKKS ciphertext using
the provided float diagonals.

The linear transform is defined by a set of diagonals, where each diagonal
represents a specific shift and scaling of the input ciphertext slots.

The `diagonals` input is a 2D tensor where each row represents one non-zero
diagonal of the square matrix to evaluate. The diagonal values are floats
that will be encoded into plaintexts during code generation.

The `levelQ` attribute specifies the modulus level at which the operation
should be performed.

The `logBabyStepGiantStepRatio` attribute is used to optimize the linear
transformation using the baby-step giant-step algorithm. It defines the
ratio between the sizes of the baby steps and giant steps. If unset,
it is zero by default.

During code generation, this op will:
1. Create a lintrans.Diagonals map from the input tensor
2. Create and encode a lintrans.Transformation
3. Create a lintrans.Evaluator
4. Evaluate the transformation on the input ciphertext
}];
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_CKKSEncoder:$encoder,
Lattigo_RLWECiphertext:$input,
// Parameters corresponding to the lattigo lintrans.Parameters struct,
// and those that can be inferred from the IR are commented out.

RankedTensorOf<[AnyFloat]>:$diagonals,
DenseI32ArrayAttr:$diagonal_indices,
Builtin_IntegerAttr:$levelQ,

// The same auxiliary prime used to generate evaluation keys
// Builtin_IntegerAttr:$levelP,

// For CKKS these are hard-coded to [1, N/2] for N = ring modulus degree
// Builtin_IntegerAttr:$logDimensionsRows,
// Builtin_IntegerAttr:$logDimensionsCols,

Builtin_IntegerAttr:$logBabyStepGiantStepRatio
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_CKKSNewPolynomialEvaluatorOp : Lattigo_CKKSOp<"new_polynomial_evaluator"> {
let summary = "Create a new polynomial evaluator in the Lattigo CKKS dialect";
let description = [{
This operation creates a new evaluator for evaluating polynomials in the
Lattigo CKKS dialect.
}];
let arguments = (ins
Lattigo_CKKSParameter:$params,
Lattigo_CKKSEvaluator:$evaluator
);
let results = (outs Lattigo_CKKSPolynomialEvaluator:$polynomialEvaluator);
}


def Lattigo_CKKSChebyshevOp : Lattigo_CKKSOp<"chebyshev", [AllTypesMatch<["ciphertext", "output"]>]> {
let summary = "Evaluate a chebyshev polynomial on a lattigo CKKS ciphertext";
let description = [{
This operation evaluates a chebyshev polynomial on a CKKS ciphertext using
the Lattigo polynomial evaluator API.

The codegen constructs the Lattigo polynomial and, using an existing polynomial
evaluator, calls the evaluation routine with a given target scale.
}];
let arguments = (ins
Lattigo_CKKSPolynomialEvaluator:$evaluator,
Lattigo_RLWECiphertext:$ciphertext,
ArrayAttr:$coefficients,
Builtin_IntegerAttr:$targetScale
);
let results = (outs Lattigo_RLWECiphertext:$output);
}


def Lattigo_CKKSBootstrapOp : Lattigo_CKKSOp<"bootstrap"> {
let summary = "Bootstrap a lattigo CKKS ciphertext";
let arguments = (ins
Lattigo_CKKSBootstrapper:$bootstrapper,
Lattigo_RLWECiphertext:$ciphertext
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

#endif // LIB_DIALECT_LATTIGO_IR_LATTIGOCKKSOPS_TD_
Loading
Loading