Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions cel/folding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,10 @@ func TestConstantFoldingOptimizer(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
opt, err := NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand Down Expand Up @@ -441,7 +444,10 @@ func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
opt, err := NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if tc.error != "" {
if iss.Err() == nil {
Expand Down Expand Up @@ -508,7 +514,10 @@ func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
opt, err := NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand Down Expand Up @@ -570,7 +579,10 @@ func TestConstantFoldingOptimizerWithLimit(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
opt, err := NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand Down Expand Up @@ -828,7 +840,10 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
opt, err := NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand Down
20 changes: 16 additions & 4 deletions cel/inlining_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ func TestInliningOptimizer(t *testing.T) {
t.Fatalf("Compile() failed: %v", iss.Err())
}

opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand All @@ -236,7 +239,10 @@ func TestInliningOptimizer(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt = cel.NewStaticOptimizer(folder)
opt, err = cel.NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss = opt.Optimize(e, optimized)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand Down Expand Up @@ -727,7 +733,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) {
t.Fatalf("Compile() failed: %v", iss.Err())
}

opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
opt, err := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...))
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss := opt.Optimize(e, checked)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand All @@ -743,7 +752,10 @@ func TestInliningOptimizerMultiStage(t *testing.T) {
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt = cel.NewStaticOptimizer(folder)
opt, err = cel.NewStaticOptimizer(folder)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optimized, iss = opt.Optimize(e, optimized)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand Down
54 changes: 46 additions & 8 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cel

import (
"fmt"
"sort"

"github.com/google/cel-go/common"
Expand All @@ -29,17 +30,43 @@ import (
// passes to ensure that the final optimized output is a valid expression with metadata consistent
// with what would have been generated from a parsed and checked expression.
//
// Note: source position information is best-effort and likely wrong, but optimized expressions
// Note: source position information is best-effort and incomplete, but optimized expressions
// should be suitable for calls to parser.Unparse.
type StaticOptimizer struct {
optimizers []ASTOptimizer
// If set, Optimize() will use this Source instead of the one from the AST.
sourceOverride *Source
}

type OptimizerOption func(*StaticOptimizer) (*StaticOptimizer, error)

// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied
// to a checked expression.
func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
return &StaticOptimizer{
optimizers: optimizers,
func NewStaticOptimizer(options ...any) (*StaticOptimizer, error) {
so := &StaticOptimizer{}
var err error
for _, opt := range options {
switch v := opt.(type) {
case ASTOptimizer:
so.optimizers = append(so.optimizers, v)
case OptimizerOption:
so, err = v(so)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unsupported option: %v", v)
}
}
return so, nil
}

// OptimizeWithSource overrides the source used by the optimizer.
// Note this will cause the source info from the AST passed to Optimize() to be discarded.
func OptimizeWithSource(source Source) OptimizerOption {
return func(so *StaticOptimizer) (*StaticOptimizer, error) {
so.sourceOverride = &source
return so, nil
}
}

Expand All @@ -49,15 +76,21 @@ func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.NativeRep())
source := a.Source()
sourceInfo := optimized.SourceInfo()
if opt.sourceOverride != nil {
source = *opt.sourceOverride
sourceInfo = ast.NewSourceInfo(*opt.sourceOverride)
}
ids := newIDGenerator(ast.MaxID(a.NativeRep()))

// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
issues := NewIssues(common.NewErrors(source))
baseFac := ast.NewExprFactory()
exprFac := &optimizerExprFactory{
idGenerator: ids,
fac: baseFac,
sourceInfo: optimized.SourceInfo(),
sourceInfo: sourceInfo,
}
ctx := &OptimizerContext{
optimizerExprFactory: exprFac,
Expand All @@ -80,7 +113,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {

// Recheck the updated expression for any possible type-agreement or validation errors.
parsed := &Ast{
source: a.Source(),
source: source,
impl: ast.NewAST(expr, info)}
checked, iss := ctx.Check(parsed)
if iss.Err() != nil {
Expand All @@ -91,7 +124,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {

// Return the optimized result.
return &Ast{
source: a.Source(),
source: source,
impl: optimized,
}, nil
}
Expand All @@ -100,6 +133,8 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
optimized.RenumberIDs(idGen)
info.RenumberIDs(idGen)

if len(info.MacroCalls()) == 0 {
return
}
Expand Down Expand Up @@ -260,6 +295,9 @@ func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr {
for macroID, call := range copyInfo.MacroCalls() {
opt.SetMacroCall(macroID, call)
}
for id, offset := range copyInfo.OffsetRanges() {
opt.sourceInfo.SetOffsetRange(id, offset)
}
return copyExpr
}

Expand Down
98 changes: 82 additions & 16 deletions cel/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
opt := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()})
opt, err := cel.NewStaticOptimizer(&testOptimizer{t: t, inlineExpr: inlinedAST.NativeRep()})
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optAST, iss := opt.Optimize(e, exprAST)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand All @@ -59,28 +62,17 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
if err != nil {
t.Fatalf("cel.AstToCheckedExpr() failed: %v", err)
}
sourceInfoPB.Positions = nil
wantTextPB := `
location: "<input>"
line_offsets: 9
positions: {
key: 2
value: 4
}
positions: {
key: 3
value: 5
}
positions: {
key: 4
value: 3
}
macro_calls: {
key: 1
value: {
call_expr: {
function: "has"
args: {
id: 21
id: 24
select_expr: {
operand: {
id: 2
Expand Down Expand Up @@ -186,7 +178,10 @@ func TestStaticOptimizerNewAST(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("Compile(%q) failed: %v", tc, iss.Err())
}
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t})
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optAST, iss := opt.Optimize(e, exprAST)
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
Expand All @@ -202,9 +197,69 @@ func TestStaticOptimizerNewAST(t *testing.T) {
}
}

func TestOptimizeWithSource(t *testing.T) {
initial := `has(a.b)`
replacement := `x["a"]`
e := optimizerEnv(t)
initialAST, iss := e.Compile(initial)
if iss.Err() != nil {
t.Fatalf("Compile(%q) failed: %v", initial, iss.Err())
}
replacementAST, iss := e.Compile(replacement)
if iss.Err() != nil {
t.Fatalf("Compile(%q) failed: %v", replacement, iss.Err())
}

opt, err := cel.NewStaticOptimizer(
&replaceOptimizer{t: t, targetAST: replacementAST.NativeRep()},
cel.OptimizeWithSource(replacementAST.Source()),
)
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optAST, iss := opt.Optimize(e, initialAST)
if iss.Err() != nil {
t.Fatalf("Optimize() returned an error: %v", iss.Err())
}

if optAST.Source().Content() != replacement {
t.Errorf("got source content %q, wanted %q", optAST.Source().Content(), replacement)
}
sourceInfoPB, err := ast.SourceInfoToProto(optAST.NativeRep().SourceInfo())
if err != nil {
t.Fatalf("cel.AstToCheckedExpr() failed: %v", err)
}
wantTextPB := `
location: "<input>"
line_offsets: 7
positions: {
key: 1
value: 1
}
positions: {
key: 2
value: 0
}
positions: {
key: 3
value: 2
}
`
var wantSourceInfoPB exprpb.SourceInfo
if err := prototext.Unmarshal([]byte(wantTextPB), &wantSourceInfoPB); err != nil {
t.Fatalf("prototext.Unmarshal() failed: %v", err)
}
if !proto.Equal(&wantSourceInfoPB, sourceInfoPB) {
t.Errorf("got source info: %s, wanted %s", prototext.Format(sourceInfoPB), wantTextPB)
}
}

func TestStaticOptimizerNilAST(t *testing.T) {
env := optimizerEnv(t)
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
opt, err := cel.NewStaticOptimizer(&identityOptimizer{t: t})
if err != nil {
t.Fatalf("NewStaticOptimizer() failed: %v", err)
}
optAST, iss := opt.Optimize(env, nil)
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") {
t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss)
Expand Down Expand Up @@ -245,6 +300,17 @@ func (opt *testOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.A
return ctx.NewAST(a.Expr())
}

type replaceOptimizer struct {
t *testing.T
targetAST *ast.AST
}

func (opt *replaceOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
opt.t.Helper()
copy := ctx.CopyASTAndMetadata(opt.targetAST)
return ctx.NewAST(copy)
}

func getMacroKeys(macroCalls map[int64]ast.Expr) []int {
keys := []int{}
for k := range macroCalls {
Expand Down
Loading