From 9f8eb9b5852acd7df9d5229b02dfe2b8a9da4c15 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 6 Apr 2026 16:05:17 -0700 Subject: [PATCH] use form traits in PolyMulToNTT where possible PiperOrigin-RevId: 895546447 --- lib/Dialect/Polynomial/IR/PolynomialOps.td | 7 +- .../Polynomial/Transforms/PolyMulToNTT.cpp | 102 ++++++++++++------ 2 files changed, 70 insertions(+), 39 deletions(-) diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.td b/lib/Dialect/Polynomial/IR/PolynomialOps.td index d1d2874c41..ee11b8c069 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialOps.td +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.td @@ -199,7 +199,7 @@ def Polynomial_MonicMonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypes let results = (outs PolynomialLike:$output); } -def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> { +def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure, SameOperandsAndResultForm]> { let summary = "Creates a polynomial from integer coefficients or evaluations stored in a tensor."; let description = [{ `polynomial.from_tensor` creates a polynomial value from a tensor of @@ -236,7 +236,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> { let hasVerifier = 1; } -def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> { +def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure, SameOperandsAndResultForm]> { let summary = "Creates a tensor containing the coefficients or evaluations of a polynomial."; let description = [{ `polynomial.to_tensor` creates a dense tensor value containing the @@ -291,7 +291,6 @@ def Polynomial_ModSwitchOp : Polynomial_Op<"mod_switch", [Pure, SameOperandsAndR let hasVerifier = 1; } - def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[ Polynomial_TypedFloatPolynomialAttr, Polynomial_TypedIntPolynomialAttr, @@ -441,7 +440,7 @@ def Polynomial_YieldOp : Polynomial_Op<"yield", [Terminator, HasParent<"ApplyCoe } def Polynomial_ApplyCoefficientwiseOp : Polynomial_Op<"apply_coefficientwise", [ - Pure, SingleBlock]> { + Pure, SingleBlock, FixedFormCoeff]> { let summary = "Apply a region to each coefficient of a polynomial."; let description = [{ `polynomial.apply_coefficientwise` takes a polynomial and applies a series diff --git a/lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp b/lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp index 05c48a776c..d61f844ced 100644 --- a/lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp +++ b/lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp @@ -2,6 +2,7 @@ #include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "lib/Dialect/Polynomial/IR/PolynomialOps.h" +#include "lib/Dialect/Polynomial/IR/PolynomialTraits.h" #include "lib/Dialect/Polynomial/IR/PolynomialTypes.h" #include "lib/Dialect/Polynomial/Transforms/NTTSolver.h" #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project @@ -57,21 +58,30 @@ enum class OpFormClass { }; OpFormClass opFormClass(Operation* op) { + // This special case must come first because it overrides the more general + // SameOperandsAndResultForm trait for MulOp. + if (op->hasTrait() || isa(op)) { + return OpFormClass::EVAL; + } + if (isa(op)) { return OpFormClass::RETURN; - } else if (isa( - op)) { + } + + if (op->hasTrait() || isa(op)) { return OpFormClass::COEFF; - } else if (isa(op)) { - return OpFormClass::EVAL; - } else if (isa(op)) { + } + + if (op->hasTrait() || + isa( + op)) { return OpFormClass::EITHER; - } else if (isa(op)) { + } + + if (isa(op)) { return OpFormClass::CONST; } + return OpFormClass::UNKNOWN; } @@ -244,34 +254,48 @@ void PolyMulToNTT::runOnOperation() { // Eval poly inputs and outputs; this is really a mirror of the previous // case else if (opClass == OpFormClass::EVAL) { - Value y = polyResults[0]; - // Since this op outputs eval form, the use of coeff form implies the - // use of eval form - solver.implyForm(y, Form::COEFF, Form::EVAL); - // There's a conversion cost if y_c is needed - solver.addConversionCostForForm(y, Form::COEFF); - for (Value x : polyOperands) { - // Use of output in eval form implies use of input in eval form - solver.implyUse(y, x, Form::EVAL); + if (polyResults.empty()) { + for (Value v : polyOperands) { + solver.forceDemandFixedForm(v, Form::EVAL); + } + } else { + Value y = polyResults[0]; + // Since this op outputs eval form, the use of coeff form implies the + // use of eval form + solver.implyForm(y, Form::COEFF, Form::EVAL); + // There's a conversion cost if y_c is needed + solver.addConversionCostForForm(y, Form::COEFF); + for (Value x : polyOperands) { + // Use of output in eval form implies use of input in eval form + solver.implyUse(y, x, Form::EVAL); + } } } // Ops that work in either form, as long as inputs and outputs are all // "uni-form" else if (opClass == OpFormClass::EITHER) { - Value y = polyResults[0]; - // Since the value output by this op can be in either form, it gets a - // 'mode' variable. In short, if y_c is needed and y_e is not, we run the - // op in coeff mode, and vice versa. - solver.addOpMode(y); - for (Value x : polyOperands) { - // if y_mode = 0 and output (in either form) is needed, the inputs in - // coeff form are required if y_mode = 1 and output (in either form) is - // needed, the inputs in eval form are required - solver.implyMode(y, x); + if (polyResults.empty()) { + // If an op has no poly results, we don't have a mode variable to attach + // to it, so we just allow each operand to be in either form. + for (Value v : polyOperands) { + solver.forceDemandEitherForm(v); + } + } else { + Value y = polyResults[0]; + // Since the value output by this op can be in either form, it gets a + // 'mode' variable. In short, if y_c is needed and y_e is not, we run + // the op in coeff mode, and vice versa. + solver.addOpMode(y); + for (Value x : polyOperands) { + // if y_mode = 0 and output (in either form) is needed, the inputs in + // coeff form are required if y_mode = 1 and output (in either form) + // is needed, the inputs in eval form are required + solver.implyMode(y, x); + } + // The only time there's a conversion cost is if both forms are needed. + // If only one form is needed, the op runs in that mode. + solver.addConversionCostIfBothForms(y); } - // The only time there's a conversion cost is if both forms are needed. If - // only one form is needed, the op runs in that mode. - solver.addConversionCostIfBothForms(y); } // Ops that produce polynomials in any form. We can pre-compute these // constants in either (or both!) form(s) @@ -615,10 +639,18 @@ void PolyMulToNTT::runOnOperation() { // Ops that work in either form, as long as inputs and outputs are all // "uni-form" else if (opClass == OpFormClass::EITHER) { - Value v = polyResults[0]; - Form form = soln.getMode(v); - for (OpOperand* arg : polyOperands) { - arg->set(formToValue(arg->get(), form)); + if (polyResults.empty()) { + for (OpOperand* arg : polyOperands) { + Value v = arg->get(); + Form form = soln.needsForm(v, Form::COEFF) ? Form::COEFF : Form::EVAL; + arg->set(formToValue(v, form)); + } + } else { + Value v = polyResults[0]; + Form form = soln.getMode(v); + for (OpOperand* arg : polyOperands) { + arg->set(formToValue(arg->get(), form)); + } } } else { op->emitOpError(