Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/Transforms/Secretize/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ cc_library(
],
deps = [
":pass_inc_gen",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand Down
264 changes: 212 additions & 52 deletions lib/Transforms/Secretize/WrapGeneric.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
#include <utility>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Secret/IR/SecretDialect.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Transforms/Secretize/Passes.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project

namespace mlir {
Expand All @@ -29,8 +32,8 @@ namespace heir {
#include "lib/Transforms/Secretize/Passes.h.inc"

struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
WrapWithGeneric(mlir::MLIRContext* context)
: mlir::OpRewritePattern<func::FuncOp>(context) {}
WrapWithGeneric(mlir::MLIRContext* context, DataFlowSolver* solver)
: mlir::OpRewritePattern<func::FuncOp>(context), solver(solver) {}

LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter& rewriter) const override {
Expand Down Expand Up @@ -58,54 +61,203 @@ struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
return rewriter.notifyMatchFailure(op, "no secret inputs found");
}

auto newOutputs = llvm::to_vector<6>(llvm::map_range(
op.getResultTypes(),
[](Type t) -> Type { return secret::SecretType::get(t); }));
// Externally defined functions have no body - conservatively wrap all
// outputs
if (op.isDeclaration()) {
SmallVector<Type, 6> newOutputs;
for (Type resultType : op.getResultTypes()) {
newOutputs.push_back(secret::SecretType::get(resultType));
}
rewriter.modifyOpInPlace(op, [&] {
op.setFunctionType(
FunctionType::get(getContext(), {newInputs}, {newOutputs}));
});
return success();
}

// Phase 1: Identify which operations depend on secrets
Block& opEntryBlock = op.getRegion().front();
auto* returnOp = opEntryBlock.getTerminator();

// Track which values are secret (including block arguments)
llvm::DenseSet<Value> secretValues;
for (unsigned i = 0; i < op.getNumArguments(); i++) {
if (isSecret(op.getArgument(i), solver)) {
secretValues.insert(op.getArgument(i));
}
}

// Track which operations are secret-dependent
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I think instead of cloning only secret ops in, I think this particular pattern should just be wrapping the entire block in a generic, but modifying just the selective output wrapping that you did with newOutputs.

There are other patterns that would also benefit from taking the solver in this file, but maybe we can go pattern by pattern to integrate the secretness analysis (for example, HoistPlaintextOps would benefit, and that would hoist plaintext ops outside of the secret body if they can be hoisted).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the core issue is that if you secret.yield a plaintext SSA value, it becomes secret according to the secretness analysis and won't be automatically converted to a public value. We had this special behavior back in the CGGI pipeline because we wanted to make an empty memref secret as the initializer for a loop that put secret items inside it.

Maybe the relevant pattern could have an option to control its behavior, and we could have this pass specialize... Thoughts?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (!op.getOps<memref::AllocOp>().empty()) {

Right - that's why I think this pattern should also take secretness analysis and we should test if that allocated value will be used for secret storing values later. (That being said, I think that would mean secretness analysis would need a backwards analysis as well). But I think that's why I'd prefer this pass stay minimal and tackle that problem in a later PR

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we agree, but that would make this issue unresolvable until the secretness analysis is improved.

llvm::DenseSet<Operation*> secretOps;
for (Operation& bodyOp : opEntryBlock) {
if (&bodyOp == returnOp) continue;

// modification to function type should go through the rewriter
// An operation is secret if any of its operands are secret
bool isSecretOp = llvm::any_of(bodyOp.getOperands(), [&](Value operand) {
return secretValues.contains(operand) || isSecret(operand, solver);
});

if (isSecretOp) {
secretOps.insert(&bodyOp);
// All results of a secret op become secret
for (Value result : bodyOp.getResults()) {
secretValues.insert(result);
}
}
}

// Phase 2: Determine output types and which outputs need to be in generic
SmallVector<Type, 6> newOutputs;
SmallVector<Value> secretReturnValues;
SmallVector<Value> plaintextReturnValues;
SmallVector<unsigned> secretReturnIndices;
SmallVector<unsigned> plaintextReturnIndices;

for (auto [i, resultType] : llvm::enumerate(op.getResultTypes())) {
Value returnVal = returnOp->getOperand(i);
if (secretValues.contains(returnVal) || isSecret(returnVal, solver)) {
newOutputs.push_back(secret::SecretType::get(resultType));
secretReturnValues.push_back(returnVal);
secretReturnIndices.push_back(i);
} else {
newOutputs.push_back(resultType);
plaintextReturnValues.push_back(returnVal);
plaintextReturnIndices.push_back(i);
}
}

// Modification to function type should go through the rewriter
rewriter.modifyOpInPlace(op, [&] {
op.setFunctionType(
FunctionType::get(getContext(), {newInputs}, {newOutputs}));
});

// Externally defined functions have no body
if (op.isDeclaration()) {
// If there are no secret-dependent operations AND no secret return values,
// we don't need a generic at all (purely plaintext function).
// But if there are secret return values (e.g., function directly returns
// its secret input), we still need a generic even with no operations.
if (secretOps.empty() && secretReturnValues.empty()) {
// Purely plaintext function - no generic needed
return success();
}
// Create a new block where we will insert the new secret.generic and move
// the function ops into.
Block& opEntryBlock = op.getRegion().front();

// Phase 3: Collect inputs for the secret.generic block
// These are: (1) secret arguments, (2) plaintext values used by secret ops
SmallVector<Value> genericInputs;
SmallVector<Type> genericInputTypes;

// Add all function arguments that are used by secret ops (or are secret)
for (unsigned i = 0; i < op.getNumArguments(); i++) {
genericInputs.push_back(op.getArgument(i));
genericInputTypes.push_back(op.getArgument(i).getType());
}

// Collect plaintext-defined values that are used inside secret ops
SmallVector<Value> plaintextValuesUsedInGeneric;
for (Operation* secretOp : secretOps) {
for (Value operand : secretOp->getOperands()) {
// If the operand is from outside the secretOps set (i.e., plaintext)
if (!secretValues.contains(operand)) {
Operation* defOp = operand.getDefiningOp();
// It's a plaintext value defined by a non-secret op in this function
if (defOp && !secretOps.contains(defOp) &&
defOp->getParentRegion() == &op.getRegion()) {
if (!llvm::is_contained(plaintextValuesUsedInGeneric, operand)) {
plaintextValuesUsedInGeneric.push_back(operand);
genericInputs.push_back(operand);
genericInputTypes.push_back(operand.getType());
}
}
}
}
}

// Phase 4: Build the secret.generic with only secret ops
SmallVector<Type> genericOutputTypes;
for (Value v : secretReturnValues) {
genericOutputTypes.push_back(secret::SecretType::get(v.getType()));
}

// Create a new block for the rewritten function
auto* newBlock = rewriter.createBlock(
&opEntryBlock, opEntryBlock.getArgumentTypes(),
SmallVector<Location>(opEntryBlock.getNumArguments(), op.getLoc()));

rewriter.setInsertionPointToStart(newBlock);

// Build mapping from old block args to new block args
IRMapping outerMapping;
for (unsigned i = 0; i < opEntryBlock.getNumArguments(); ++i) {
outerMapping.map(opEntryBlock.getArgument(i), newBlock->getArgument(i));
}

// Clone plaintext operations to the new block (before the generic)
for (Operation& bodyOp : opEntryBlock) {
if (&bodyOp == returnOp) continue;
if (!secretOps.contains(&bodyOp)) {
Operation* clonedOp = rewriter.clone(bodyOp, outerMapping);
for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) {
outerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i));
}
}
}

// Update genericInputs to use the new block's values
SmallVector<Value> mappedGenericInputs;
for (Value v : genericInputs) {
mappedGenericInputs.push_back(outerMapping.lookupOrDefault(v));
}

// Now create the secret.generic
auto newGeneric = secret::GenericOp::create(
rewriter, op.getLoc(), op.getArguments(), newOutputs,
rewriter, op.getLoc(), mappedGenericInputs, genericOutputTypes,
[&](OpBuilder& b, Location loc, ValueRange blockArguments) {
// Map the input values to the block arguments.
IRMapping mp;
for (unsigned i = 0; i < blockArguments.size(); ++i) {
mp.map(opEntryBlock.getArgument(i), blockArguments[i]);
// Map inputs to block arguments
IRMapping innerMapping;
for (unsigned i = 0; i < genericInputs.size(); ++i) {
innerMapping.map(genericInputs[i], blockArguments[i]);
}

// Clone only secret operations into the generic
for (Operation& bodyOp : opEntryBlock) {
if (&bodyOp == returnOp) continue;
if (secretOps.contains(&bodyOp)) {
Operation* clonedOp = b.clone(bodyOp, innerMapping);
for (unsigned i = 0; i < bodyOp.getNumResults(); ++i) {
innerMapping.map(bodyOp.getResult(i), clonedOp->getResult(i));
}
}
}

auto* returnOp = opEntryBlock.getTerminator();
secret::YieldOp::create(b, loc,
llvm::to_vector(llvm::map_range(
returnOp->getOperands(), [&](Value v) {
return mp.lookupOrDefault(v);
})));
returnOp->erase();
// Yield only the secret return values
SmallVector<Value> yieldValues;
for (Value v : secretReturnValues) {
yieldValues.push_back(innerMapping.lookupOrDefault(v));
}
secret::YieldOp::create(b, loc, yieldValues);
});

Block& genericBlock = newGeneric.getRegion().front();
rewriter.inlineBlockBefore(&opEntryBlock,
&genericBlock.getOperations().back(),
genericBlock.getArguments());
func::ReturnOp::create(rewriter, op.getLoc(), newGeneric.getResults());
// Build the final return values in the correct order
SmallVector<Value> finalReturnValues(op.getNumResults());
unsigned secretResultIdx = 0;
for (unsigned idx : secretReturnIndices) {
finalReturnValues[idx] = newGeneric.getResult(secretResultIdx++);
}
for (unsigned idx : plaintextReturnIndices) {
Value returnVal = returnOp->getOperand(idx);
finalReturnValues[idx] = outerMapping.lookupOrDefault(returnVal);
}

func::ReturnOp::create(rewriter, op.getLoc(), finalReturnValues);

// Erase the old block
rewriter.eraseBlock(&opEntryBlock);

return success();
}

private:
DataFlowSolver* solver;
};

struct ConvertFuncCall : public OpRewritePattern<func::CallOp> {
Expand Down Expand Up @@ -159,20 +311,28 @@ struct WrapGeneric : impl::WrapGenericBase<WrapGeneric> {
using WrapGenericBase::WrapGenericBase;

void detectSecretGeneric() {
bool hasSecretGeneric = false;
getOperation().walk([&](secret::GenericOp op) { hasSecretGeneric = true; });
if (!hasSecretGeneric) {
getOperation().emitWarning(
"No secret found in the module. Did you forget to annotate "
"{secret.secret} to the function arguments?");
}
// Note: Since we now correctly handle functions that return only
// plaintext values (which don't get a secret.generic), we should not
// warn about missing secret.generic ops. The warning was intended
// for the case where users forgot to annotate secret inputs, but that
// is already caught by the hasSecrets check in WrapWithGeneric.
}

void runOnOperation() override {
MLIRContext* context = &getContext();

// Run SecretnessAnalysis to determine which values depend on secrets
DataFlowSolver solver;
dataflow::loadBaselineAnalyses(solver);
solver.load<SecretnessAnalysis>();
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run SecretnessAnalysis.\n";
signalPassFailure();
return;
}

mlir::RewritePatternSet patterns(context);
patterns.add<WrapWithGeneric>(context);
patterns.add<WrapWithGeneric>(context, &solver);
(void)walkAndApplyPatterns(getOperation(), std::move(patterns));

// func.call should be converted after callee func type updated
Expand Down
17 changes: 17 additions & 0 deletions tests/Dialect/Secret/Transforms/wrap_generic/wrap_generic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,20 @@ module {
return %alloc : memref<1x80xi8>
}
}

// -----

// Regression test for issue #2553: plaintext constant should not become secret
// When a function only returns values that don't depend on secrets,
// no secret.generic should be created.
module {
// CHECK: @plaintext_output(%arg0: !secret.secret<i32>) -> i8
func.func @plaintext_output(%x: i32 {secret.secret}) -> i8 {
// The constant does not depend on the secret input
// CHECK-NOT: secret.generic
%0 = arith.constant 42 : i8
// CHECK: return %{{.*}} : i8
func.return %0 : i8
}
}