Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions lib/Dialect/Polynomial/IR/PolynomialOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def Polynomial_MonicMonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypes
let results = (outs PolynomialLike:$output);
}

def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure, SameOperandsAndResultForm]> {
let summary = "Creates a polynomial from integer coefficients or evaluations stored in a tensor.";
let description = [{
`polynomial.from_tensor` creates a polynomial value from a tensor of
Expand Down Expand Up @@ -236,7 +236,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
let hasVerifier = 1;
}

def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure, SameOperandsAndResultForm]> {
let summary = "Creates a tensor containing the coefficients or evaluations of a polynomial.";
let description = [{
`polynomial.to_tensor` creates a dense tensor value containing the
Expand Down Expand Up @@ -291,7 +291,6 @@ def Polynomial_ModSwitchOp : Polynomial_Op<"mod_switch", [Pure, SameOperandsAndR
let hasVerifier = 1;
}


def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
Polynomial_TypedFloatPolynomialAttr,
Polynomial_TypedIntPolynomialAttr,
Expand Down Expand Up @@ -441,7 +440,7 @@ def Polynomial_YieldOp : Polynomial_Op<"yield", [Terminator, HasParent<"ApplyCoe
}

def Polynomial_ApplyCoefficientwiseOp : Polynomial_Op<"apply_coefficientwise", [
Pure, SingleBlock]> {
Pure, SingleBlock, FixedFormCoeff]> {
let summary = "Apply a region to each coefficient of a polynomial.";
let description = [{
`polynomial.apply_coefficientwise` takes a polynomial and applies a series
Expand Down
102 changes: 67 additions & 35 deletions lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "lib/Dialect/Polynomial/IR/PolynomialOps.h"
#include "lib/Dialect/Polynomial/IR/PolynomialTraits.h"
#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "lib/Dialect/Polynomial/Transforms/NTTSolver.h"
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
Expand Down Expand Up @@ -57,21 +58,30 @@ enum class OpFormClass {
};

OpFormClass opFormClass(Operation* op) {
// This special case must come first because it overrides the more general
// SameOperandsAndResultForm trait for MulOp.
if (op->hasTrait<FixedFormEval>() || isa<MulOp>(op)) {
return OpFormClass::EVAL;
}

if (isa<func::ReturnOp>(op)) {
return OpFormClass::RETURN;
} else if (isa<ToTensorOp, LeadingTermOp, EvalOp, ConvertBasisOp,
MonicMonomialMulOp, FromTensorOp, ApplyCoefficientwiseOp>(
op)) {
}

if (op->hasTrait<FixedFormCoeff>() || isa<ToTensorOp, FromTensorOp>(op)) {
return OpFormClass::COEFF;
} else if (isa<MulOp>(op)) {
return OpFormClass::EVAL;
} else if (isa<AddOp, SubOp, MulScalarOp, ModSwitchOp, ExtractSliceOp,
tensor::ExtractSliceOp, tensor::ExtractOp,
tensor::FromElementsOp>(op)) {
}

if (op->hasTrait<SameOperandsAndResultForm>() ||
isa<tensor::ExtractSliceOp, tensor::ExtractOp, tensor::FromElementsOp>(
op)) {
return OpFormClass::EITHER;
} else if (isa<MonomialOp, ConstantOp>(op)) {
}

if (isa<MonomialOp, ConstantOp>(op)) {
return OpFormClass::CONST;
}

return OpFormClass::UNKNOWN;
}

Expand Down Expand Up @@ -244,34 +254,48 @@ void PolyMulToNTT::runOnOperation() {
// Eval poly inputs and outputs; this is really a mirror of the previous
// case
else if (opClass == OpFormClass::EVAL) {
Value y = polyResults[0];
// Since this op outputs eval form, the use of coeff form implies the
// use of eval form
solver.implyForm(y, Form::COEFF, Form::EVAL);
// There's a conversion cost if y_c is needed
solver.addConversionCostForForm(y, Form::COEFF);
for (Value x : polyOperands) {
// Use of output in eval form implies use of input in eval form
solver.implyUse(y, x, Form::EVAL);
if (polyResults.empty()) {
for (Value v : polyOperands) {
solver.forceDemandFixedForm(v, Form::EVAL);
}
} else {
Value y = polyResults[0];
// Since this op outputs eval form, the use of coeff form implies the
// use of eval form
solver.implyForm(y, Form::COEFF, Form::EVAL);
// There's a conversion cost if y_c is needed
solver.addConversionCostForForm(y, Form::COEFF);
for (Value x : polyOperands) {
// Use of output in eval form implies use of input in eval form
solver.implyUse(y, x, Form::EVAL);
}
}
}
// Ops that work in either form, as long as inputs and outputs are all
// "uni-form"
else if (opClass == OpFormClass::EITHER) {
Value y = polyResults[0];
// Since the value output by this op can be in either form, it gets a
// 'mode' variable. In short, if y_c is needed and y_e is not, we run the
// op in coeff mode, and vice versa.
solver.addOpMode(y);
for (Value x : polyOperands) {
// if y_mode = 0 and output (in either form) is needed, the inputs in
// coeff form are required if y_mode = 1 and output (in either form) is
// needed, the inputs in eval form are required
solver.implyMode(y, x);
if (polyResults.empty()) {
// If an op has no poly results, we don't have a mode variable to attach
// to it, so we just allow each operand to be in either form.
for (Value v : polyOperands) {
solver.forceDemandEitherForm(v);
}
} else {
Value y = polyResults[0];
// Since the value output by this op can be in either form, it gets a
// 'mode' variable. In short, if y_c is needed and y_e is not, we run
// the op in coeff mode, and vice versa.
solver.addOpMode(y);
for (Value x : polyOperands) {
// if y_mode = 0 and output (in either form) is needed, the inputs in
// coeff form are required if y_mode = 1 and output (in either form)
// is needed, the inputs in eval form are required
solver.implyMode(y, x);
}
// The only time there's a conversion cost is if both forms are needed.
// If only one form is needed, the op runs in that mode.
solver.addConversionCostIfBothForms(y);
}
// The only time there's a conversion cost is if both forms are needed. If
// only one form is needed, the op runs in that mode.
solver.addConversionCostIfBothForms(y);
}
// Ops that produce polynomials in any form. We can pre-compute these
// constants in either (or both!) form(s)
Expand Down Expand Up @@ -615,10 +639,18 @@ void PolyMulToNTT::runOnOperation() {
// Ops that work in either form, as long as inputs and outputs are all
// "uni-form"
else if (opClass == OpFormClass::EITHER) {
Value v = polyResults[0];
Form form = soln.getMode(v);
for (OpOperand* arg : polyOperands) {
arg->set(formToValue(arg->get(), form));
if (polyResults.empty()) {
for (OpOperand* arg : polyOperands) {
Value v = arg->get();
Form form = soln.needsForm(v, Form::COEFF) ? Form::COEFF : Form::EVAL;
arg->set(formToValue(v, form));
}
} else {
Value v = polyResults[0];
Form form = soln.getMode(v);
for (OpOperand* arg : polyOperands) {
arg->set(formToValue(arg->get(), form));
}
}
} else {
op->emitOpError(
Expand Down
Loading