diff --git a/WORKSPACE b/WORKSPACE index f566d7d09..956f16a4d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -101,8 +101,8 @@ go_repository( go_repository( name = "dev_cel_expr", importpath = "cel.dev/expr", - sum = "h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI=", - version = "v0.22.1", + sum = "h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg=", + version = "v0.23.1", ) # local_repository( @@ -153,7 +153,7 @@ go_repository( # of the above repositories but at different versions, so ours must come first. go_rules_dependencies() -go_register_toolchains(version = "1.21.1") +go_register_toolchains(version = "1.22.0") gazelle_dependencies() diff --git a/conformance/go.mod b/conformance/go.mod index 115630be9..92d2ca476 100644 --- a/conformance/go.mod +++ b/conformance/go.mod @@ -3,7 +3,7 @@ module github.com/google/cel-go/conformance go 1.22.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/bazelbuild/rules_go v0.49.0 github.com/google/cel-go v0.21.0 github.com/google/go-cmp v0.6.0 diff --git a/conformance/go.sum b/conformance/go.sum index 9544b17a9..95ac73a52 100644 --- a/conformance/go.sum +++ b/conformance/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/bazelbuild/rules_go v0.49.0 h1:5vCbuvy8Q11g41lseGJDc5vxhDjJtfxr6nM/IC4VmqM= diff --git a/go.mod b/go.mod index 9f089f4fd..914c1ec28 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.22.0 toolchain go1.23.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/stoewer/go-strcase v1.2.0 google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 diff --git a/go.sum b/go.sum index 062b316c3..23fe2170a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index f058f1ab9..15facc55d 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -63,6 +63,7 @@ go_test( embed = [":go_default_library"], deps = [ "//cel:go_default_library", + "//test:go_default_library", "//common/types:go_default_library", "//interpreter:go_default_library", "//common/types/ref:go_default_library", @@ -72,6 +73,6 @@ go_test( ) filegroup( - name = "k8s_policy_testdata", - srcs = glob(["testdata/k8s/*"]), -) \ No newline at end of file + name = "testdata", + srcs = glob(["testdata/**"]), +) diff --git a/policy/compiler_test.go b/policy/compiler_test.go index 545b01f8b..4bd1bfcce 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -243,24 +243,25 @@ func (r *runner) run(t *testing.T) { input := map[string]any{} var err error var activation interpreter.Activation - for k, v := range tc.Input { - if v.Expr != "" { - input[k] = r.eval(t, v.Expr) - continue + if tc.InputContext != nil && tc.InputContext.ContextExpr != "" { + ctxExpr := tc.InputContext.ContextExpr + ctx, err := r.eval(t, ctxExpr).ConvertToNative( + reflect.TypeOf(((*proto.Message)(nil))).Elem()) + if err != nil { + t.Fatalf("context variable is not a valid proto: %v", err) } - if v.ContextExpr != "" { - ctx, err := r.eval(t, v.ContextExpr).ConvertToNative( - reflect.TypeOf(((*proto.Message)(nil))).Elem()) - if err != nil { - t.Fatalf("context variable is not a valid proto: %v", err) - } - activation, err = cel.ContextProtoVars(ctx.(proto.Message)) - if err != nil { - t.Fatalf("cel.ContextProtoVars() failed: %v", err) + activation, err = cel.ContextProtoVars(ctx.(proto.Message)) + if err != nil { + t.Fatalf("cel.ContextProtoVars() failed: %v", err) + } + } else if len(tc.Input) != 0 { + for k, v := range tc.Input { + if v.Expr != "" { + input[k] = r.eval(t, v.Expr) + continue } - break + input[k] = v.Value } - input[k] = v.Value } if activation == nil { activation, err = interpreter.NewActivation(input) @@ -272,7 +273,12 @@ func (r *runner) run(t *testing.T) { if err != nil { t.Fatalf("prg.Eval(input) failed: %v", err) } - testOut := r.eval(t, tc.Output) + var testOut ref.Val + if tc.Output.Expr != "" { + testOut = r.eval(t, tc.Output.Expr) + } else if tc.Output.Value != nil { + testOut = r.env.CELTypeAdapter().NativeToValue(tc.Output.Value) + } if optOut, ok := out.(*types.Optional); ok { if optOut.Equal(types.OptionalNone) == types.True { if testOut.Equal(types.OptionalNone) != types.True { @@ -299,24 +305,25 @@ func (r *runner) bench(b *testing.B) { input := map[string]any{} var err error var activation interpreter.Activation - for k, v := range tc.Input { - if v.Expr != "" { - input[k] = r.eval(b, v.Expr) - continue + if tc.InputContext != nil && tc.InputContext.ContextExpr != "" { + ctxExpr := tc.InputContext.ContextExpr + ctx, err := r.eval(b, ctxExpr).ConvertToNative( + reflect.TypeOf(((*proto.Message)(nil))).Elem()) + if err != nil { + b.Fatalf("context variable is not a valid proto: %v", err) } - if v.ContextExpr != "" { - ctx, err := r.eval(b, v.ContextExpr).ConvertToNative( - reflect.TypeOf(((*proto.Message)(nil))).Elem()) - if err != nil { - b.Fatalf("context variable is not a valid proto: %v", err) - } - activation, err = cel.ContextProtoVars(ctx.(proto.Message)) - if err != nil { - b.Fatalf("cel.ContextProtoVars() failed: %v", err) + activation, err = cel.ContextProtoVars(ctx.(proto.Message)) + if err != nil { + b.Fatalf("cel.ContextProtoVars() failed: %v", err) + } + } else if tc.Input != nil { + for k, v := range tc.Input { + if v.Expr != "" { + input[k] = r.eval(b, v.Expr) + continue } - break + input[k] = v.Value } - input[k] = v.Value } if activation == nil { activation, err = interpreter.NewActivation(input) diff --git a/policy/conformance.go b/policy/conformance.go index 3d05f411c..160c5c87e 100644 --- a/policy/conformance.go +++ b/policy/conformance.go @@ -15,6 +15,8 @@ package policy // TestSuite describes a set of tests divided by section. +// +// Deprecated: Use google3/third_party/cel/go/test/suite.go instead. type TestSuite struct { Description string `yaml:"description"` Sections []*TestSection `yaml:"section"` diff --git a/policy/go.mod b/policy/go.mod index 410781f5d..400a3cac5 100644 --- a/policy/go.mod +++ b/policy/go.mod @@ -9,7 +9,7 @@ require ( ) require ( - cel.dev/expr v0.22.1 // indirect + cel.dev/expr v0.23.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect diff --git a/policy/go.sum b/policy/go.sum index 8b4ac4221..35ef4fed1 100644 --- a/policy/go.sum +++ b/policy/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/policy/helper_test.go b/policy/helper_test.go index 8e117331c..98ca4f7ab 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/cel-go/common/env" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/test" "gopkg.in/yaml.v3" @@ -447,13 +448,13 @@ func readPolicyConfig(t testing.TB, fileName string) *env.Config { return config } -func readTestSuite(t testing.TB, fileName string) *TestSuite { +func readTestSuite(t testing.TB, fileName string) *test.Suite { t.Helper() testCaseBytes, err := os.ReadFile(fileName) if err != nil { t.Fatalf("os.ReadFile(%s) failed: %v", fileName, err) } - suite := &TestSuite{} + suite := &test.Suite{} err = yaml.Unmarshal(testCaseBytes, suite) if err != nil { log.Fatalf("yaml.Unmarshal(%s) error: %v", fileName, err) diff --git a/policy/testdata/context_pb/config.yaml b/policy/testdata/context_pb/config.yaml index e7804575c..53ea95425 100644 --- a/policy/testdata/context_pb/config.yaml +++ b/policy/testdata/context_pb/config.yaml @@ -15,8 +15,6 @@ name: "context_pb" container: "google.expr.proto3" extensions: - - name: "optional" - version: "latest" - name: "strings" version: 2 context_variable: diff --git a/policy/testdata/context_pb/tests.yaml b/policy/testdata/context_pb/tests.yaml index 11e377e53..80849783b 100644 --- a/policy/testdata/context_pb/tests.yaml +++ b/policy/testdata/context_pb/tests.yaml @@ -16,18 +16,13 @@ description: "Protobuf input tests" section: - name: "valid" tests: - - name: "good spec" - input: - spec: - context_expr: > - test.TestAllTypes{single_int32: 10} - output: "optional.none()" + - name: "good spec" + context_expr: "test.TestAllTypes{single_int32: 10}" + output: + expr: "optional.none()" - name: "invalid" tests: - - name: "bad spec" - input: - spec: - context_expr: > - test.TestAllTypes{single_int32: 11} - output: > - "invalid spec, got single_int32=11, wanted <= 10" + - name: "bad spec" + context_expr: "test.TestAllTypes{single_int32: 11}" + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/policy/testdata/k8s/config.yaml b/policy/testdata/k8s/config.yaml index 5a2cb3290..15a32b535 100644 --- a/policy/testdata/k8s/config.yaml +++ b/policy/testdata/k8s/config.yaml @@ -14,10 +14,6 @@ name: k8s extensions: - - name: "optional" - version: "latest" - - name: "bindings" - version: "latest" - name: "strings" version: 2 variables: diff --git a/policy/testdata/k8s/tests.yaml b/policy/testdata/k8s/tests.yaml index 3965ea0f9..f3e7de790 100644 --- a/policy/testdata/k8s/tests.yaml +++ b/policy/testdata/k8s/tests.yaml @@ -28,4 +28,5 @@ section: - staging.dev.cel.container1 - staging.dev.cel.container2 - preprod.dev.cel.container3 - output: "'only staging containers are allowed in namespace dev.cel'" + output: + value: "only staging containers are allowed in namespace dev.cel" diff --git a/policy/testdata/limits/tests.yaml b/policy/testdata/limits/tests.yaml index 8f50a519d..88772e075 100644 --- a/policy/testdata/limits/tests.yaml +++ b/policy/testdata/limits/tests.yaml @@ -20,19 +20,23 @@ section: input: now: expr: "timestamp('2024-07-30T00:30:00Z')" - output: "'hello, me'" + output: + value: "hello, me" - name: "8pm" input: now: expr: "timestamp('2024-07-30T20:30:00Z')" - output: "'goodbye, me!'" + output: + value: "goodbye, me!" - name: "9pm" input: now: expr: "timestamp('2024-07-30T21:30:00Z')" - output: "'goodbye, me!!'" + output: + value: "goodbye, me!!" - name: "11pm" input: now: expr: "timestamp('2024-07-30T23:30:00Z')" - output: "'goodbye, me!!!'" + output: + value: "goodbye, me!!!" diff --git a/policy/testdata/nested_rule/tests.yaml b/policy/testdata/nested_rule/tests.yaml index 48101c89a..3f9f63437 100644 --- a/policy/testdata/nested_rule/tests.yaml +++ b/policy/testdata/nested_rule/tests.yaml @@ -21,13 +21,15 @@ section: resource: value: origin: "ir" - output: "{'banned': true}" + output: + expr: "{'banned': true}" - name: "by_default" input: resource: value: origin: "de" - output: "{'banned': true}" + output: + expr: "{'banned': true}" - name: "permitted" tests: - name: "valid_origin" @@ -35,4 +37,5 @@ section: resource: value: origin: "uk" - output: "{'banned': false}" + output: + expr: "{'banned': false}" diff --git a/policy/testdata/nested_rule2/tests.yaml b/policy/testdata/nested_rule2/tests.yaml index ac725956c..0e1a9ca69 100644 --- a/policy/testdata/nested_rule2/tests.yaml +++ b/policy/testdata/nested_rule2/tests.yaml @@ -22,21 +22,24 @@ section: value: user: "bad-user" origin: "ir" - output: "{'banned': 'restricted_region'}" + output: + expr: "{'banned': 'restricted_region'}" - name: "by_default" input: resource: value: user: "bad-user" origin: "de" - output: "{'banned': 'bad_actor'}" + output: + expr: "{'banned': 'bad_actor'}" - name: "unconfigured_region" input: resource: value: user: "good-user" origin: "de" - output: "{'banned': 'unconfigured_region'}" + output: + expr: "{'banned': 'unconfigured_region'}" - name: "permitted" tests: - name: "valid_origin" @@ -45,4 +48,5 @@ section: value: user: "good-user" origin: "uk" - output: "{}" + output: + expr: "{}" diff --git a/policy/testdata/nested_rule3/tests.yaml b/policy/testdata/nested_rule3/tests.yaml index ece86eba0..9d993c65f 100644 --- a/policy/testdata/nested_rule3/tests.yaml +++ b/policy/testdata/nested_rule3/tests.yaml @@ -22,21 +22,24 @@ section: value: user: "bad-user" origin: "ir" - output: "{'banned': 'restricted_region'}" + output: + expr: "{'banned': 'restricted_region'}" - name: "by_default" input: resource: value: user: "bad-user" origin: "de" - output: "{'banned': 'bad_actor'}" + output: + expr: "{'banned': 'bad_actor'}" - name: "unconfigured_region" input: resource: value: user: "good-user" origin: "de" - output: "{'banned': 'unconfigured_region'}" + output: + expr: "{'banned': 'unconfigured_region'}" - name: "permitted" tests: - name: "valid_origin" @@ -45,4 +48,5 @@ section: value: user: "good-user" origin: "uk" - output: "optional.none()" + output: + expr: "optional.none()" diff --git a/policy/testdata/nested_rule4/tests.yaml b/policy/testdata/nested_rule4/tests.yaml index a5af137f3..006eddb88 100644 --- a/policy/testdata/nested_rule4/tests.yaml +++ b/policy/testdata/nested_rule4/tests.yaml @@ -20,9 +20,11 @@ section: input: x: value: 0 - output: "false" + output: + value: false - name: "x=2" input: x: value: 2 - output: "true" + output: + value: true diff --git a/policy/testdata/nested_rule5/tests.yaml b/policy/testdata/nested_rule5/tests.yaml index 66cc44507..8cd794051 100644 --- a/policy/testdata/nested_rule5/tests.yaml +++ b/policy/testdata/nested_rule5/tests.yaml @@ -20,19 +20,23 @@ section: input: x: value: 0 - output: "false" + output: + value: false - name: "x=1" input: x: value: 1 - output: "optional.none()" + output: + expr: "optional.none()" - name: "x=2" input: x: value: 2 - output: "optional.none()" + output: + expr: "optional.none()" - name: "x=3" input: x: value: 3 - output: "true" + output: + value: true diff --git a/policy/testdata/nested_rule6/tests.yaml b/policy/testdata/nested_rule6/tests.yaml index dabce623c..fef586df0 100644 --- a/policy/testdata/nested_rule6/tests.yaml +++ b/policy/testdata/nested_rule6/tests.yaml @@ -20,4 +20,5 @@ section: input: x: value: 0 - output: "false" + output: + value: false diff --git a/policy/testdata/nested_rule7/tests.yaml b/policy/testdata/nested_rule7/tests.yaml index 7844e18f6..f740c7639 100644 --- a/policy/testdata/nested_rule7/tests.yaml +++ b/policy/testdata/nested_rule7/tests.yaml @@ -20,19 +20,23 @@ section: input: x: value: 1 - output: "optional.none()" + output: + expr: "optional.none()" - name: "x=2" input: x: value: 2 - output: "false" + output: + value: false - name: "x=3" input: x: value: 3 - output: "true" + output: + value: true - name: "x=4" input: x: value: 4 - output: "true" + output: + value: true diff --git a/policy/testdata/pb/tests.yaml b/policy/testdata/pb/tests.yaml index 770bcad09..a39f7b73f 100644 --- a/policy/testdata/pb/tests.yaml +++ b/policy/testdata/pb/tests.yaml @@ -21,7 +21,8 @@ section: spec: expr: > test.TestAllTypes{single_int32: 10} - output: "optional.none()" + output: + expr: "optional.none()" - name: "invalid" tests: - name: "bad spec" @@ -29,5 +30,5 @@ section: spec: expr: > test.TestAllTypes{single_int32: 11} - output: > - "invalid spec, got single_int32=11, wanted <= 10" + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/policy/testdata/required_labels/config.yaml b/policy/testdata/required_labels/config.yaml index f9081478a..c5c612e20 100644 --- a/policy/testdata/required_labels/config.yaml +++ b/policy/testdata/required_labels/config.yaml @@ -14,7 +14,6 @@ name: "labels" extensions: - - name: "bindings" - name: "strings" version: 2 - name: "two-var-comprehensions" diff --git a/policy/testdata/required_labels/tests.yaml b/policy/testdata/required_labels/tests.yaml index a4bf96dc2..2159b1b24 100644 --- a/policy/testdata/required_labels/tests.yaml +++ b/policy/testdata/required_labels/tests.yaml @@ -29,7 +29,8 @@ section: env: prod experiment: "group b" release: "v0.1.0" - output: "optional.none()" + output: + expr: "optional.none()" - name: "missing" tests: - name: "env" @@ -44,8 +45,8 @@ section: labels: experiment: "group b" release: "v0.1.0" - output: > - "missing one or more required labels: [\"env\"]" + output: + value: "missing one or more required labels: [\"env\"]" - name: "experiment" input: spec: @@ -58,8 +59,8 @@ section: labels: env: staging release: "v0.1.0" - output: > - "missing one or more required labels: [\"experiment\"]" + output: + value: "missing one or more required labels: [\"experiment\"]" - name: "invalid" tests: - name: "env" @@ -75,5 +76,5 @@ section: env: staging experiment: "group b" release: "v0.1.0" - output: > - "invalid values provided on one or more labels: [\"env\"]" + output: + value: "invalid values provided on one or more labels: [\"env\"]" diff --git a/policy/testdata/restricted_destinations/base_config.yaml b/policy/testdata/restricted_destinations/base_config.yaml index 2aae385ca..615a8b915 100644 --- a/policy/testdata/restricted_destinations/base_config.yaml +++ b/policy/testdata/restricted_destinations/base_config.yaml @@ -14,26 +14,26 @@ name: "labels" extensions: -- name: "lists" -- name: "sets" + - name: "lists" + - name: "sets" variables: -- name: "destination.ip" - type_name: "string" -- name: "origin.ip" - type_name: "string" -- name: "spec.restricted_destinations" - type_name: "list" - params: - - type_name: "string" -- name: "spec.origin" - type_name: "string" -- name: "request" - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" -- name: "resource" - type_name: "map" - params: - - type_name: "string" - - type_name: "dyn" + - name: "destination.ip" + type_name: "string" + - name: "origin.ip" + type_name: "string" + - name: "spec.restricted_destinations" + type_name: "list" + params: + - type_name: "string" + - name: "spec.origin" + type_name: "string" + - name: "request" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" + - name: "resource" + type_name: "map" + params: + - type_name: "string" + - type_name: "dyn" diff --git a/policy/testdata/restricted_destinations/tests.yaml b/policy/testdata/restricted_destinations/tests.yaml index 1cf59fe62..e448fb1a9 100644 --- a/policy/testdata/restricted_destinations/tests.yaml +++ b/policy/testdata/restricted_destinations/tests.yaml @@ -40,7 +40,8 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "us" - output: "false" # false means unrestricted + output: + value: false # false means unrestricted - name: "nationality_allowed" input: "spec.origin": @@ -64,7 +65,8 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "us" - output: "false" + output: + value: false - name: "invalid" tests: - name: "destination_ip_prohibited" @@ -91,7 +93,8 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "us" - output: "true" # true means restricted + output: + value: true # true means restricted - name: "resource_nationality_prohibited" input: "spec.origin": @@ -115,4 +118,5 @@ section: name: "/company/acme/secrets/doomsday-device" labels: location: "cu" - output: "true" + output: + value: true diff --git a/policy/testdata/unnest/tests.yaml b/policy/testdata/unnest/tests.yaml index 9bed7b352..31a8770d7 100644 --- a/policy/testdata/unnest/tests.yaml +++ b/policy/testdata/unnest/tests.yaml @@ -20,31 +20,33 @@ section: input: values: expr: "[4, 6]" - output: > - "some divisible by 2" + output: + value: "some divisible by 2" - name: "false" input: values: expr: "[1, 3, 5]" - output: "optional.none()" + output: + expr: "optional.none()" - name: "empty-set" input: values: expr: "[1, 2]" - output: "optional.none()" + output: + expr: "optional.none()" - name: "divisible by 4" tests: - name: "true" input: values: expr: "[4, 7]" - output: > - "at least one divisible by 4" + output: + value: "at least one divisible by 4" - name: "power of 6" tests: - name: "true" input: values: expr: "[6, 7]" - output: > - "at least one power of 6" + output: + value: "at least one power of 6" diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 59bd9a3dc..37b093a2a 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -9,6 +9,8 @@ package( "//interpreter:__subpackages__", "//parser:__subpackages__", "//server:__subpackages__", + "//tools:__subpackages__", + "//policy:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) @@ -18,6 +20,7 @@ go_library( srcs = [ "compare.go", "expr.go", + "suite.go", ], importpath = "github.com/google/cel-go/test", deps = [ diff --git a/test/suite.go b/test/suite.go new file mode 100644 index 000000000..2b499e45d --- /dev/null +++ b/test/suite.go @@ -0,0 +1,60 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +// Suite is a collection of tests designed to evaluate the correctness of +// a CEL policy or a CEL expression +type Suite struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Sections []*Section `yaml:"section"` +} + +// Section is a collection of related test cases. +type Section struct { + Name string `yaml:"name"` + Tests []*Case `yaml:"tests"` +} + +// Case is a test case to validate a CEL policy or expression. The test case +// encompasses evaluation of the compiled expression using the provided input +// bindings and asserting the result against the expected result. +type Case struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Input map[string]*InputValue `yaml:"input,omitempty"` + *InputContext `yaml:",inline,omitempty"` + Output *Output `yaml:"output"` +} + +// InputContext represents an optional context expression. +type InputContext struct { + ContextExpr string `yaml:"context_expr"` +} + +// InputValue represents an input value for a binding which can be either a simple literal value or +// an expression. +type InputValue struct { + Value any `yaml:"value"` + Expr string `yaml:"expr"` +} + +// Output represents the expected result of a test case. +type Output struct { + Value any `yaml:"value"` + Expr string `yaml:"expr"` + ErrorSet []string `yaml:"error_set"` + UnknownSet []int64 `yaml:"unknown_set"` +} diff --git a/tools/celtest/BUILD.bazel b/tools/celtest/BUILD.bazel new file mode 100644 index 000000000..295f4e755 --- /dev/null +++ b/tools/celtest/BUILD.bazel @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +go_library( + name = "go_default_library", + srcs = [ + "test_runner.go", + ], + importpath = "github.com/google/cel-go/tools/celtest", + deps = [ + "//cel:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//interpreter:go_default_library", + "//test:go_default_library", + "//tools/compiler:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@dev_cel_expr//:expr", + "@dev_cel_expr//conformance/test:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + "@io_bazel_rules_go//go/runfiles", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//reflect/protodesc:go_default_library", + "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", + "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", + "@org_golang_google_protobuf//testing/protocmp:go_default_library", + "@org_golang_google_protobuf//types/descriptorpb:go_default_library", + "@org_golang_google_protobuf//types/dynamicpb:go_default_library", + ], +) + +go_test( + name = "go_default_test", + size = "small", + srcs = [ + "test_runner_test.go", + ], + data = [ + ":testdata", + "//policy:testdata", + ], + embed = [":go_default_library"], + deps = [ + "//cel:go_default_library", + "//common/decls:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//policy:go_default_library", + "//tools/compiler:go_default_library", + "@in_gopkg_yaml_v3//:go_default_library", + ] +) + +filegroup( + name = "testdata", + srcs = glob(["testdata/**"]), +) \ No newline at end of file diff --git a/tools/celtest/test_runner.go b/tools/celtest/test_runner.go new file mode 100644 index 000000000..39f29dea3 --- /dev/null +++ b/tools/celtest/test_runner.go @@ -0,0 +1,707 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package celtest provides functions for testing CEL policies and expressions. +package celtest + +import ( + "flag" + "fmt" + "os" + "reflect" + "strings" + "testing" + + "gopkg.in/yaml.v3" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/test" + "github.com/google/cel-go/tools/compiler" + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protocmp" + + celpb "cel.dev/expr" + conformancepb "cel.dev/expr/conformance/test" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + descpb "google.golang.org/protobuf/types/descriptorpb" + dynamicpb "google.golang.org/protobuf/types/dynamicpb" +) + +var ( + celExpression string + testSuitePath string + fileDescriptorSetPath string + configPath string + baseConfigPath string +) + +func init() { + flag.StringVar(&testSuitePath, "test_suite_path", "", "path to a test suite") + flag.StringVar(&fileDescriptorSetPath, "file_descriptor_set", "", "path to a file descriptor set") + flag.StringVar(&configPath, "config_path", "", "path to a config file") + flag.StringVar(&baseConfigPath, "base_config_path", "", "path to a base config file") + flag.StringVar(&celExpression, "cel_expr", "", "CEL expression to test") + flag.Parse() +} + +// TestRunnerOption is used to configure the following attributes of the Test Runner: +// - set the Compiler +// - add Input Expressions +// - set the test suite file path +// - set the test suite parser based on the file format: YAML or Textproto +type TestRunnerOption func(*TestRunner) (*TestRunner, error) + +// TriggerTests triggers tests for a CEL policy, expression or checked expression +// with the provided set of options. The options can be used to: +// - configure the Compiler used for parsing and compiling the expression +// - configure the Test Runner used for parsing and executing the tests +func TriggerTests(t *testing.T, testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) { + testRunnerOptions := testRunnerOptions(testRunnerOpts, testCompilerOpts...) + tr, err := NewTestRunner(testRunnerOptions...) + if err != nil { + t.Fatalf("error creating test runner: %v", err) + } + programs, err := tr.Programs(t) + if err != nil { + t.Fatalf("error creating programs: %v", err) + } + tests, err := tr.Tests(t) + if err != nil { + t.Fatalf("error creating tests: %v", err) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := tr.ExecuteTest(t, programs, test) + if err != nil { + t.Fatalf("error executing test: %v", err) + } + }) + } +} + +func testRunnerOptions(testRunnerOpts []TestRunnerOption, testCompilerOpts ...any) []TestRunnerOption { + compilerOpt := testRunnerCompilerFromFlags(testCompilerOpts...) + testSuiteParserOpt := DefaultTestSuiteParser(testSuitePath) + fileDescriptorSetOpt := AddFileDescriptorSet(fileDescriptorSetPath) + testRunnerExprOpt := testRunnerExpressionsFromFlags() + return append([]TestRunnerOption{compilerOpt, testSuiteParserOpt, fileDescriptorSetOpt, testRunnerExprOpt}, testRunnerOpts...) +} + +func testRunnerCompilerFromFlags(testCompilerOpts ...any) TestRunnerOption { + var opts []any + if fileDescriptorSetPath != "" { + opts = append(opts, compiler.TypeDescriptorSetFile(fileDescriptorSetPath)) + } + if baseConfigPath != "" { + opts = append(opts, compiler.EnvironmentFile(baseConfigPath)) + } + if configPath != "" { + opts = append(opts, compiler.EnvironmentFile(configPath)) + } + opts = append(opts, testCompilerOpts...) + return func(tr *TestRunner) (*TestRunner, error) { + c, err := compiler.NewCompiler(opts...) + if err != nil { + return nil, err + } + tr.Compiler = c + return tr, nil + } +} + +func testRunnerExpressionsFromFlags() TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if celExpression != "" { + tr.Expressions = append(tr.Expressions, &compiler.CompiledExpression{Path: celExpression}) + tr.Expressions = append(tr.Expressions, &compiler.FileExpression{Path: celExpression}) + tr.Expressions = append(tr.Expressions, &compiler.RawExpression{Value: celExpression}) + } + return tr, nil + } +} + +// TestSuiteParser is an interface for parsing a test suite: +// - ParseTextproto: Returns a cel.spec.expr.conformance.test.TestSuite message. +// - ParseYAML: Returns a test.Suite object. +// In case the test suite is serialized in a Textproto/YAML file, the path of the file is passed as +// an argument to the parse method. +type TestSuiteParser interface { + ParseTextproto(string) (*conformancepb.TestSuite, error) + ParseYAML(string) (*test.Suite, error) +} + +type tsParser struct { + TestSuiteParser +} + +// ParseTextproto parses a test suite file in Textproto format. +func (p *tsParser) ParseTextproto(path string) (*conformancepb.TestSuite, error) { + if path == "" { + return nil, nil + } + if fileFormat := compiler.InferFileFormat(path); fileFormat != compiler.TextProto { + return nil, fmt.Errorf("invalid file extension wanted: .textproto: found %v", fileFormat) + } + testSuite := &conformancepb.TestSuite{} + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("os.ReadFile(%q) failed: %v", path, err) + } + err = prototext.Unmarshal(data, testSuite) + return testSuite, err +} + +// ParseYAML parses a test suite file in YAML format. +func (p *tsParser) ParseYAML(path string) (*test.Suite, error) { + if path == "" { + return nil, nil + } + if fileFormat := compiler.InferFileFormat(path); fileFormat != compiler.TextYAML { + return nil, fmt.Errorf("invalid file extension wanted: .yaml: found %v", fileFormat) + } + testSuiteBytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("os.ReadFile(%q) failed: %v", path, err) + } + testSuite := &test.Suite{} + err = yaml.Unmarshal(testSuiteBytes, testSuite) + return testSuite, err +} + +// DefaultTestSuiteParser returns a TestRunnerOption which configures the test runner with a test suite parser. +func DefaultTestSuiteParser(path string) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if path == "" { + return tr, nil + } + tr.TestSuiteFilePath = path + tr.testSuiteParser = &tsParser{} + return tr, nil + } +} + +// TestRunner provides a structure to hold the different components required to execute tests for +// a list of Input Expressions. The TestRunner can be configured with the following options: +// - Compiler: The compiler used for parsing and compiling the input expressions. +// - Input Expressions: The list of input expressions to be tested. +// - Test Suite File Path: The path to the test suite file. +// - File Descriptor Set Path: The path to the file descriptor set file. +// - test Suite Parser: A parser for a test suite file serialized in Textproto/YAML format. +// +// The TestRunner provides the following methods: +// - Programs: Creates a list of CEL programs from the input expressions. +// - Tests: Creates a list of tests from the test suite file. +// - ExecuteTest: Executes a single +type TestRunner struct { + compiler.Compiler + Expressions []compiler.InputExpression + TestSuiteFilePath string + FileDescriptorSetPath string + testSuiteParser TestSuiteParser +} + +// Test represents a single test case to be executed. It encompasses the following: +// - name: The name of the test case. +// - input: The input to be used for evaluating the CEL expression. +// - resultMatcher: A function that takes in the result of evaluating the CEL expression and +// returns a TestResult. +type Test struct { + name string + input interpreter.Activation + resultMatcher func(ref.Val, error) TestResult +} + +// NewTest creates a new Test with the provided name, input and result matcher. +func NewTest(name string, input interpreter.Activation, resultMatcher func(ref.Val, error) TestResult) *Test { + return &Test{ + name: name, + input: input, + resultMatcher: resultMatcher, + } +} + +// TestResult represents the result of a test case execution. It contains the validation result +// along with the expected result and any errors encountered during the execution. +// - Success: Whether the result matcher condition validating the test case was satisfied. +// - Wanted: The expected result of the test case. +// - Error: Any error encountered during the execution. +type TestResult struct { + Success bool + Wanted string + Error error +} + +// NewTestRunner creates a Test Runner with the provided options. +// The options can be used to: +// - configure the Compiler used for parsing and compiling the input expressions +// - configure the Test Runner used for parsing and executing the tests +func NewTestRunner(opts ...TestRunnerOption) (*TestRunner, error) { + tr := &TestRunner{} + var err error + for _, opt := range opts { + tr, err = opt(tr) + if err != nil { + return nil, err + } + } + return tr, nil +} + +// AddFileDescriptorSet creates a Test Runner Option which adds a file descriptor set to the test +// runner. The file descriptor set is used to register proto messages in the global proto registry. +func AddFileDescriptorSet(path string) TestRunnerOption { + return func(tr *TestRunner) (*TestRunner, error) { + if path != "" { + tr.FileDescriptorSetPath = path + } + return tr, nil + } +} + +func registerMessages(path string) error { + if path == "" { + return nil + } + fds, err := fileDescriptorSet(path) + if err != nil { + return err + } + for _, file := range fds.GetFile() { + reflectFD, err := protodesc.NewFile(file, protoregistry.GlobalFiles) + if err != nil { + return fmt.Errorf("protodesc.NewFile(%q) failed: %v", file.GetName(), err) + } + if _, err := protoregistry.GlobalFiles.FindFileByPath(reflectFD.Path()); err == nil { + continue + } + err = protoregistry.GlobalFiles.RegisterFile(reflectFD) + if err != nil { + return fmt.Errorf("protoregistry.GlobalFiles.RegisterFile() failed: %v", err) + } + for i := 0; i < reflectFD.Messages().Len(); i++ { + msg := reflectFD.Messages().Get(i) + msgType := dynamicpb.NewMessageType(msg) + err = protoregistry.GlobalTypes.RegisterMessage(msgType) + if err != nil && !strings.Contains(err.Error(), "already registered") { + return fmt.Errorf("protoregistry.GlobalTypes.RegisterMessage(%q) failed: %v", msgType, err) + } + } + } + return nil +} + +func fileDescriptorSet(path string) (*descpb.FileDescriptorSet, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file descriptor set file %q: %v", fileDescriptorSetPath, err) + } + fds := &descpb.FileDescriptorSet{} + if err := proto.Unmarshal(bytes, fds); err != nil { + return nil, fmt.Errorf("failed to unmarshal file descriptor set file %q: %v", fileDescriptorSetPath, err) + } + return fds, nil +} + +// Programs creates a list of CEL programs from the input expressions configured in the test runner +// using the provided program options. +func (tr *TestRunner) Programs(t *testing.T, opts ...cel.ProgramOption) ([]cel.Program, error) { + t.Helper() + if tr.Compiler == nil { + return nil, fmt.Errorf("compiler is not set") + } + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + var programs []cel.Program + for _, expr := range tr.Expressions { + // TODO: propagate metadata map along with the program instance as a struct. + ast, _, err := expr.CreateAST(tr.Compiler) + if err != nil { + if strings.Contains(err.Error(), "invalid file extension") || + strings.Contains(err.Error(), "invalid raw expression") { + continue + } + return nil, err + } + prg, err := e.Program(ast, opts...) + if err != nil { + return nil, err + } + programs = append(programs, prg) + } + return programs, nil +} + +// Tests creates a list of tests from the test suite file and test suite parser configured in the +// test runner. +func (tr *TestRunner) Tests(t *testing.T) ([]*Test, error) { + if tr.Compiler == nil { + return nil, fmt.Errorf("compiler is not set") + } + if tr.testSuiteParser == nil { + return nil, fmt.Errorf("test suite parser is not set") + } + if testSuite, err := tr.testSuiteParser.ParseYAML(tr.TestSuiteFilePath); err != nil && + !strings.Contains(err.Error(), "invalid file extension") { + return nil, fmt.Errorf("tr.testSuiteParser.ParseYAML(%q) failed: %v", tr.TestSuiteFilePath, err) + } else if testSuite != nil { + return tr.createTestsFromYAML(t, testSuite) + } + err := registerMessages(tr.FileDescriptorSetPath) + if err != nil { + return nil, fmt.Errorf("registerMessages(%q) failed: %v", tr.FileDescriptorSetPath, err) + } + if testSuite, err := tr.testSuiteParser.ParseTextproto(tr.TestSuiteFilePath); err != nil && + !strings.Contains(err.Error(), "invalid file extension") { + return nil, fmt.Errorf("tr.testSuiteParser.ParseTextproto(%q) failed: %v", tr.TestSuiteFilePath, err) + } else if testSuite != nil { + return tr.createTestsFromTextproto(t, testSuite) + } + return nil, nil +} + +func (tr *TestRunner) createTestsFromTextproto(t *testing.T, testSuite *conformancepb.TestSuite) ([]*Test, error) { + var tests []*Test + for _, section := range testSuite.GetSections() { + sectionName := section.GetName() + for _, testCase := range section.GetTests() { + testName := fmt.Sprintf("%s/%s", sectionName, testCase.GetName()) + testInput, err := tr.createTestInputFromPB(t, testCase) + if err != nil { + return nil, err + } + testResultMatcher, err := tr.createResultMatcherFromPB(t, testCase) + if err != nil { + return nil, err + } + tests = append(tests, NewTest(testName, testInput, testResultMatcher)) + } + } + return tests, nil +} + +func (tr *TestRunner) createTestInputFromPB(t *testing.T, testCase *conformancepb.TestCase) (interpreter.Activation, error) { + t.Helper() + input := map[string]any{} + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + if testCase.GetInputContext() != nil { + if len(testCase.GetInput()) != 0 { + return nil, fmt.Errorf("only one of input and input_context can be provided at a time") + } + switch testInput := testCase.GetInputContext().GetInputContextKind().(type) { + case *conformancepb.InputContext_ContextExpr: + refVal, err := tr.eval(testInput.ContextExpr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", testInput.ContextExpr, err) + } + ctx, err := refVal.ConvertToNative( + reflect.TypeOf((*proto.Message)(nil)).Elem()) + if err != nil { + return nil, fmt.Errorf("context variable is not a valid proto: %w", err) + } + return cel.ContextProtoVars(ctx.(proto.Message)) + case *conformancepb.InputContext_ContextMessage: + refVal := e.CELTypeAdapter().NativeToValue(testInput.ContextMessage) + ctx, err := refVal.ConvertToNative(reflect.TypeOf((*proto.Message)(nil)).Elem()) + if err != nil { + return nil, fmt.Errorf("context variable is not a valid proto: %w", err) + } + return cel.ContextProtoVars(ctx.(proto.Message)) + } + } + for k, v := range testCase.GetInput() { + switch v.GetKind().(type) { + case *conformancepb.InputValue_Value: + input[k], err = cel.ProtoAsValue(e.CELTypeAdapter(), v.GetValue()) + if err != nil { + return nil, fmt.Errorf("cel.ProtoAsValue(%q) failed: %w", v, err) + } + case *conformancepb.InputValue_Expr: + input[k], err = tr.eval(v.GetExpr()) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", v.GetExpr(), err) + } + } + } + return interpreter.NewActivation(input) +} + +func (tr *TestRunner) createResultMatcherFromPB(t *testing.T, testCase *conformancepb.TestCase) (func(ref.Val, error) TestResult, error) { + t.Helper() + if testCase.GetOutput() == nil { + return nil, fmt.Errorf("expected output is nil") + } + successResult := TestResult{Success: true} + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + switch testOutput := testCase.GetOutput().GetResultKind().(type) { + case *conformancepb.TestOutput_ResultValue: + return func(val ref.Val, err error) TestResult { + want := e.CELTypeAdapter().NativeToValue(testOutput.ResultValue) + if err != nil { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: err} + } + outputVal, err := refValueToExprValue(val) + if err != nil { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: fmt.Errorf("refValueToExprValue(%q) failed: %v", val, err)} + } + testResultVal, err := canonicalValueToV1Alpha1(testOutput.ResultValue) + if err != nil { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: fmt.Errorf("canonicalValueToV1Alpha1(%q) failed: %v", testOutput.ResultValue, err)} + } + testVal := &exprpb.ExprValue{ + Kind: &exprpb.ExprValue_Value{Value: testResultVal}} + + if diff := cmp.Diff(testVal, outputVal, protocmp.Transform(), + protocmp.SortRepeatedFields(&exprpb.MapValue{}, "entries")); diff != "" { + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: fmt.Errorf("mismatched test output with diff (-want +got):\n%s", diff)} + } + return successResult + }, nil + case *conformancepb.TestOutput_ResultExpr: + return func(val ref.Val, err error) TestResult { + if err != nil { + return TestResult{Success: false, Error: err} + } + testOut, err := tr.eval(testOutput.ResultExpr) + if err != nil { + return TestResult{Success: false, Error: fmt.Errorf("eval(%q) failed: %v", testOutput.ResultExpr, err)} + } + if optOut, ok := val.(*types.Optional); ok { + if optOut.Equal(types.OptionalNone) == types.True { + if testOut.Equal(types.OptionalNone) != types.True { + return TestResult{Success: false, Wanted: fmt.Sprintf("optional value %v", testOut), Error: fmt.Errorf("policy eval got %v", val)} + } + } else if testOut.Equal(optOut.GetValue()) != types.True { + return TestResult{Success: false, Wanted: fmt.Sprintf("optional value %v", testOut), Error: fmt.Errorf("policy eval got %v", val)} + } + } else if val.Equal(testOut) != types.True { + return TestResult{Success: false, Wanted: fmt.Sprintf("optional value %v", testOut), Error: fmt.Errorf("policy eval got %v", val)} + } + return successResult + }, nil + case *conformancepb.TestOutput_EvalError: + return func(val ref.Val, err error) TestResult { + failureResult := TestResult{Success: false, Wanted: fmt.Sprintf("error %v", testOutput.EvalError)} + if err == nil { + return failureResult + } + // Compare the evaluated error with the expected error message only. + for _, want := range testOutput.EvalError.GetErrors() { + if strings.Contains(err.Error(), want.GetMessage()) { + return successResult + } + } + return failureResult + }, nil + case *conformancepb.TestOutput_Unknown: + // TODO: to implement + } + return nil, nil +} + +func refValueToExprValue(refVal ref.Val) (*exprpb.ExprValue, error) { + if types.IsUnknown(refVal) { + return &exprpb.ExprValue{ + Kind: &exprpb.ExprValue_Unknown{ + Unknown: &exprpb.UnknownSet{ + Exprs: refVal.Value().([]int64), + }, + }}, nil + } + v, err := cel.RefValueToValue(refVal) + if err != nil { + return nil, err + } + return &exprpb.ExprValue{ + Kind: &exprpb.ExprValue_Value{Value: v}}, nil +} + +func canonicalValueToV1Alpha1(val *celpb.Value) (*exprpb.Value, error) { + var v1val exprpb.Value + b, err := prototext.Marshal(val) + if err != nil { + return nil, err + } + if err := prototext.Unmarshal(b, &v1val); err != nil { + return nil, err + } + return &v1val, nil +} + +func (tr *TestRunner) eval(expr string) (ref.Val, error) { + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + e, err = e.Extend(cel.OptionalTypes()) + if err != nil { + return nil, fmt.Errorf("e.Extend() failed: %v", err) + } + ast, iss := e.Compile(expr) + if iss.Err() != nil { + return nil, fmt.Errorf("e.Compile(%q) failed: %v", expr, iss.Err()) + } + prg, err := e.Program(ast) + if err != nil { + return nil, fmt.Errorf("e.Program(%q) failed: %v", expr, err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + return nil, fmt.Errorf("prg.Eval(%q) failed: %v", expr, err) + } + return out, nil +} + +func (tr *TestRunner) createTestsFromYAML(t *testing.T, testSuite *test.Suite) ([]*Test, error) { + var tests []*Test + for _, section := range testSuite.Sections { + for _, testCase := range section.Tests { + testName := fmt.Sprintf("%s/%s", section.Name, testCase.Name) + testInput, err := tr.createTestInput(t, testCase) + if err != nil { + return nil, err + } + testResultMatcher, err := tr.createResultMatcher(t, testCase.Output) + if err != nil { + return nil, err + } + tests = append(tests, NewTest(testName, testInput, testResultMatcher)) + } + } + return tests, nil +} + +func (tr *TestRunner) createTestInput(t *testing.T, testCase *test.Case) (interpreter.Activation, error) { + t.Helper() + if testCase.InputContext != nil && testCase.InputContext.ContextExpr != "" { + if len(testCase.Input) != 0 { + return nil, fmt.Errorf("only one of input and input_context can be provided at a time") + } + contextExpr := testCase.InputContext.ContextExpr + out, err := tr.eval(contextExpr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", contextExpr, err) + } + ctx, err := out.ConvertToNative(reflect.TypeOf((*proto.Message)(nil)).Elem()) + if err != nil { + return nil, fmt.Errorf("context variable is not a valid proto: %w", err) + } + return cel.ContextProtoVars(ctx.(proto.Message)) + } + input := map[string]any{} + for k, v := range testCase.Input { + if v.Expr != "" { + val, err := tr.eval(v.Expr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", v.Expr, err) + } + input[k] = val + continue + } + input[k] = v.Value + } + return interpreter.NewActivation(input) +} + +func (tr *TestRunner) createResultMatcher(t *testing.T, testOutput *test.Output) (func(ref.Val, error) TestResult, error) { + t.Helper() + e, err := tr.CreateEnv() + if err != nil { + return nil, err + } + successResult := TestResult{Success: true} + if testOutput.Value != nil { + want := e.CELTypeAdapter().NativeToValue(testOutput.Value) + return func(out ref.Val, err error) TestResult { + if err == nil { + if out.Equal(want) == types.True { + return successResult + } + if optOut, ok := out.(*types.Optional); ok { + if optOut.HasValue() && optOut.GetValue().Equal(want) == types.True { + return successResult + } + } + } + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: err} + }, nil + } + if testOutput.Expr != "" { + want, err := tr.eval(testOutput.Expr) + if err != nil { + return nil, fmt.Errorf("eval(%q) failed: %w", testOutput.Expr, err) + } + return func(out ref.Val, err error) TestResult { + if err == nil { + if out.Equal(want) == types.True { + return successResult + } + if optOut, ok := out.(*types.Optional); ok { + if optOut.HasValue() && optOut.GetValue().Equal(want) == types.True { + return successResult + } + } + } + return TestResult{Success: false, Wanted: fmt.Sprintf("simple value %v", want), Error: err} + }, nil + } + if testOutput.ErrorSet != nil { + return func(out ref.Val, err error) TestResult { + failureResult := TestResult{Success: false, Wanted: fmt.Sprintf("error %v", testOutput.ErrorSet)} + if err == nil { + return failureResult + } + for _, want := range testOutput.ErrorSet { + if strings.Contains(err.Error(), want) { + return successResult + } + } + return failureResult + }, nil + } + if testOutput.UnknownSet != nil { + // TODO: to implement + } + return nil, nil +} + +// ExecuteTest executes the test case against the provided list of programs and returns an error if +// the test fails. +func (tr *TestRunner) ExecuteTest(t *testing.T, programs []cel.Program, test *Test) error { + t.Helper() + if tr.Compiler == nil { + return fmt.Errorf("compiler is not set") + } + for _, program := range programs { + out, _, err := program.Eval(test.input) + if testResult := test.resultMatcher(out, err); !testResult.Success { + return fmt.Errorf("test: %s \n wanted: %v \n failed: %v", test.name, testResult.Wanted, testResult.Error) + } + } + return nil +} diff --git a/tools/celtest/test_runner_test.go b/tools/celtest/test_runner_test.go new file mode 100644 index 000000000..3530c2c7c --- /dev/null +++ b/tools/celtest/test_runner_test.go @@ -0,0 +1,196 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package celtest provides functions for testing CEL policies and expressions. +package celtest + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/policy" + "github.com/google/cel-go/tools/compiler" + "gopkg.in/yaml.v3" +) + +type testCase struct { + name string + celExpression string + testSuitePath string + fileDescriptorSetPath string + configPath string + opts []any +} + +func setupTests() []*testCase { + testCases := []*testCase{ + { + name: "policy test with custom policy parser", + celExpression: "../../policy/testdata/k8s/policy.yaml", + testSuitePath: "../../policy/testdata/k8s/tests.yaml", + configPath: "../../policy/testdata/k8s/config.yaml", + opts: []any{k8sParserOpts()}, + }, + { + name: "policy test with function binding", + celExpression: "../../policy/testdata/restricted_destinations/policy.yaml", + testSuitePath: "../../policy/testdata/restricted_destinations/tests.yaml", + configPath: "../../policy/testdata/restricted_destinations/config.yaml", + opts: []any{locationCodeEnvOption()}, + }, + { + name: "policy test with custom policy metadata", + celExpression: "testdata/custom_policy.celpolicy", + testSuitePath: "testdata/custom_policy_tests.yaml", + opts: []any{customPolicyParserOption(), compiler.PolicyMetadataEnvOption(ParsePolicyVariables)}, + }, + { + name: "raw expression file test", + celExpression: "testdata/raw_expr.cel", + testSuitePath: "testdata/raw_expr_tests", + configPath: "testdata/config.yaml", + opts: []any{fnEnvOption()}, + }, + { + name: "raw expression test", + celExpression: "'i + fn(j) == 42'", + testSuitePath: "testdata/raw_expr_tests", + configPath: "testdata/config.yaml", + opts: []any{fnEnvOption()}, + }, + } + return testCases +} + +func locationCodeEnvOption() cel.EnvOption { + return cel.Function("locationCode", + cel.Overload("locationCode_string", []*cel.Type{cel.StringType}, cel.StringType, + cel.UnaryBinding(locationCode))) +} + +func locationCode(ip ref.Val) ref.Val { + switch ip.(types.String) { + case "10.0.0.1": + return types.String("us") + case "10.0.0.2": + return types.String("de") + default: + return types.String("ir") + } +} + +func k8sParserOpts() policy.ParserOption { + return func(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = policy.K8sTestTagHandler() + return p, nil + } +} + +// TestTriggerTestsCustomPolicy tests the TriggerTestsFromCompiler function for a custom policy +// by providing test runner and compiler options without setting the flag variables. +func TestTriggerTestsWithRunnerOptions(t *testing.T) { + t.Run("test trigger tests custom policy", func(t *testing.T) { + envOpt := compiler.EnvironmentFile("../../policy/testdata/k8s/config.yaml") + testSuiteParser := DefaultTestSuiteParser("../../policy/testdata/k8s/tests.yaml") + testCELPolicy := TestRunnerOption(func(tr *TestRunner) (*TestRunner, error) { + tr.Expressions = append(tr.Expressions, &compiler.FileExpression{ + Path: "../../policy/testdata/k8s/policy.yaml", + }) + return tr, nil + }) + c, err := compiler.NewCompiler(envOpt, k8sParserOpts()) + if err != nil { + t.Fatalf("compiler.NewCompiler() failed: %v", err) + } + compilerOpt := TestRunnerOption(func(tr *TestRunner) (*TestRunner, error) { + tr.Compiler = c + return tr, nil + }) + opts := []TestRunnerOption{compilerOpt, testSuiteParser, testCELPolicy} + TriggerTests(t, opts) + }) +} + +func customPolicyParserOption() policy.ParserOption { + return func(p *policy.Parser) (*policy.Parser, error) { + p.TagVisitor = customTagHandler{TagVisitor: policy.DefaultTagVisitor()} + return p, nil + } +} +func ParsePolicyVariables(metadata map[string]any) cel.EnvOption { + var variables []*decls.VariableDecl + for n, t := range metadata { + variables = append(variables, decls.NewVariable(n, parseCustomPolicyVariableType(t.(string)))) + } + return cel.VariableDecls(variables...) +} + +func parseCustomPolicyVariableType(t string) *types.Type { + switch t { + case "int": + return types.IntType + case "string": + return types.StringType + default: + return types.UnknownType + } +} + +type variableType struct { + VariableName string `yaml:"variable_name"` + VariableType string `yaml:"variable_type"` +} + +type customTagHandler struct { + policy.TagVisitor +} + +func (customTagHandler) PolicyTag(ctx policy.ParserContext, id int64, tagName string, node *yaml.Node, p *policy.Policy) { + switch tagName { + case "variable_types": + var varList []*variableType + if err := node.Decode(&varList); err != nil { + ctx.ReportErrorAtID(id, "invalid yaml variable_types node: %v, error: %w", node, err) + return + } + for _, v := range varList { + p.SetMetadata(v.VariableName, v.VariableType) + } + default: + ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) + } +} + +func fnEnvOption() cel.EnvOption { + return cel.Function("fn", + cel.Overload("fn_int", []*cel.Type{cel.IntType}, cel.IntType, + cel.UnaryBinding(func(in ref.Val) ref.Val { + i := in.(types.Int) + return i / types.Int(2) + }))) +} + +// TestTriggerTests tests different scenarios of the TriggerTestsFromCompiler function. +func TestTriggerTests(t *testing.T) { + for _, tc := range setupTests() { + celExpression = tc.celExpression + testSuitePath = tc.testSuitePath + configPath = tc.configPath + fileDescriptorSetPath = tc.fileDescriptorSetPath + TriggerTests(t, nil, tc.opts...) + } +} diff --git a/tools/compiler/testdata/custom_policy_config.yaml b/tools/celtest/testdata/config.yaml similarity index 83% rename from tools/compiler/testdata/custom_policy_config.yaml rename to tools/celtest/testdata/config.yaml index 7b54a43da..62abcb23e 100644 --- a/tools/compiler/testdata/custom_policy_config.yaml +++ b/tools/celtest/testdata/config.yaml @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: "custom_policy_config" -extensions: - - name: "optional" - version: "latest" +name: "simple expression config" +variables: + - name: "i" + type_name: "int" + - name: "j" + type_name: "int" diff --git a/tools/compiler/testdata/custom_policy.celpolicy b/tools/celtest/testdata/custom_policy.celpolicy similarity index 97% rename from tools/compiler/testdata/custom_policy.celpolicy rename to tools/celtest/testdata/custom_policy.celpolicy index 663fcf0a7..3867b26fe 100644 --- a/tools/compiler/testdata/custom_policy.celpolicy +++ b/tools/celtest/testdata/custom_policy.celpolicy @@ -23,3 +23,4 @@ rule: - condition: | variable1 == 1 || variable2 == "known" output: "true" + - output: "false" \ No newline at end of file diff --git a/tools/celtest/testdata/custom_policy_tests.yaml b/tools/celtest/testdata/custom_policy_tests.yaml new file mode 100644 index 000000000..f2b554f83 --- /dev/null +++ b/tools/celtest/testdata/custom_policy_tests.yaml @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "Custom policy tests" +section: + - name: "output true" + tests: + - name: "variable 1 match" + input: + variable1: + value: 1 + output: + value: true + - name: "variable 2 match" + input: + variable1: + value: 2 + variable2: + value: "known" + output: + value: true + - name: "output false" + tests: + - name: "variable mismatch" + input: + variable1: + value: 2 + variable2: + value: "unknown" + output: + value: false diff --git a/tools/celtest/testdata/raw_expr.cel b/tools/celtest/testdata/raw_expr.cel new file mode 100644 index 000000000..63386498f --- /dev/null +++ b/tools/celtest/testdata/raw_expr.cel @@ -0,0 +1 @@ +"'i + fn(j) == 42'" \ No newline at end of file diff --git a/tools/celtest/testdata/raw_expr_tests.yaml b/tools/celtest/testdata/raw_expr_tests.yaml new file mode 100644 index 000000000..547d4319f --- /dev/null +++ b/tools/celtest/testdata/raw_expr_tests.yaml @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "simple expression tests" +section: + - name: "valid" + tests: + - name: "true" + input: + i: + value: 21 + j: + value: 42 + output: + value: true + - name: "false" + input: + i: + value: 22 + j: + value:42 + output: + value: false diff --git a/tools/compiler/BUILD.bazel b/tools/compiler/BUILD.bazel index 0c3e4080b..3b84df5ca 100644 --- a/tools/compiler/BUILD.bazel +++ b/tools/compiler/BUILD.bazel @@ -57,19 +57,16 @@ go_test( ], data = [ ":compiler_testdata", - "//policy:k8s_policy_testdata", + "//policy:testdata", ], embed = [":go_default_library"], deps = [ "//cel:go_default_library", - "//common/decls:go_default_library", "//common/env:go_default_library", - "//common/types:go_default_library", "//ext:go_default_library", "//policy:go_default_library", "@dev_cel_expr//:expr", "@dev_cel_expr//conformance:go_default_library", - "@in_gopkg_yaml_v3//:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library", ], ) diff --git a/tools/compiler/compiler.go b/tools/compiler/compiler.go index c2263e02f..272df2926 100644 --- a/tools/compiler/compiler.go +++ b/tools/compiler/compiler.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "path/filepath" + "sync" "gopkg.in/yaml.v3" @@ -57,6 +58,22 @@ const ( CELPolicy ) +// ExpressionType is an enum for the type of input expression. +type ExpressionType int + +const ( + // ExpressionTypeUnspecified is used when the expression type is not specified. + ExpressionTypeUnspecified ExpressionType = iota + // CompiledExpressionFile is file containing a checked expression. + CompiledExpressionFile + // PolicyFile is a file containing a CEL policy. + PolicyFile + // ExpressionFile is a file containing a CEL expression. + ExpressionFile + // RawExpressionString is a raw CEL expression string. + RawExpressionString +) + // PolicyMetadataEnvOption represents a function which accepts a policy metadata map and returns an // environment option used to extend the CEL environment. // @@ -90,6 +107,7 @@ type compiler struct { policyCompilerOptions []policy.CompilerOption policyMetadataEnvOptions []PolicyMetadataEnvOption env *cel.Env + doOnce sync.Once } // NewCompiler creates a new compiler with a set of functional options. @@ -114,20 +132,29 @@ func NewCompiler(opts ...any) (Compiler, error) { return nil, fmt.Errorf("unsupported compiler option: %v", opt) } } + c.envOptions = append(c.envOptions, extensionOpt()) return c, nil } +func extensionOpt() cel.EnvOption { + return func(e *cel.Env) (*cel.Env, error) { + envConfig := &env.Config{ + Extensions: []*env.Extension{ + &env.Extension{Name: "optional", Version: "latest"}, + &env.Extension{Name: "bindings", Version: "latest"}, + }, + } + return e.Extend(cel.FromConfig(envConfig, ext.ExtensionOptionFactory)) + } +} + // CreateEnv creates a singleton CEL environment with the configured environment options. func (c *compiler) CreateEnv() (*cel.Env, error) { - if c.env != nil { - return c.env, nil - } - env, err := cel.NewCustomEnv(c.envOptions...) - if err != nil { - return nil, err - } - c.env = env - return c.env, nil + var err error + c.doOnce.Do(func() { + c.env, err = cel.NewCustomEnv(c.envOptions...) + }) + return c.env, err } // CreatePolicyParser creates a policy parser using the optionally configured parser options. @@ -165,7 +192,8 @@ func loadProtoFile(path string, format FileFormat, out protoreflect.ProtoMessage return unmarshaller(data, out) } -func inferFileFormat(path string) FileFormat { +// InferFileFormat infers the file format from the file path. +func InferFileFormat(path string) FileFormat { extension := filepath.Ext(path) switch extension { case ".textproto": @@ -190,7 +218,7 @@ func inferFileFormat(path string) FileFormat { // - Binarypb func EnvironmentFile(path string) cel.EnvOption { return func(e *cel.Env) (*cel.Env, error) { - format := inferFileFormat(path) + format := InferFileFormat(path) if format != TextProto && format != TextYAML && format != BinaryProto { return nil, fmt.Errorf("file extension must be one of .textproto, .yaml, .binarypb: found %v", format) } @@ -403,7 +431,7 @@ func protoDeclToFunction(decl *celpb.Decl) (*env.Function, error) { // The file must be in binary format. func TypeDescriptorSetFile(path string) cel.EnvOption { return func(e *cel.Env) (*cel.Env, error) { - format := inferFileFormat(path) + format := InferFileFormat(path) if format != BinaryProto { return nil, fmt.Errorf("type descriptor must be in binary format") } @@ -438,9 +466,9 @@ type CompiledExpression struct { // - Textproto func (c *CompiledExpression) CreateAST(_ Compiler) (*cel.Ast, map[string]any, error) { var expr exprpb.CheckedExpr - format := inferFileFormat(c.Path) + format := InferFileFormat(c.Path) if format != BinaryProto && format != TextProto { - return nil, nil, fmt.Errorf("file extension must be .binarypb or .textproto: found %v", format) + return nil, nil, fmt.Errorf("invalid file extension wanted: .binarypb or .textproto found: %v", format) } if err := loadProtoFile(c.Path, format, &expr); err != nil { return nil, nil, err @@ -466,13 +494,13 @@ func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, if err != nil { return nil, nil, err } - data, err := loadFile(f.Path) - if err != nil { - return nil, nil, err - } - format := inferFileFormat(f.Path) + format := InferFileFormat(f.Path) switch format { case CELString: + data, err := loadFile(f.Path) + if err != nil { + return nil, nil, err + } src := common.NewStringSource(string(data), f.Path) ast, iss := e.CompileSource(src) if iss.Err() != nil { @@ -480,6 +508,10 @@ func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, } return ast, nil, nil case CELPolicy, TextYAML: + data, err := loadFile(f.Path) + if err != nil { + return nil, nil, err + } src := policy.ByteSource(data, f.Path) parser, err := compiler.CreatePolicyParser() if err != nil { @@ -501,7 +533,7 @@ func (f *FileExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, } return ast, policyMetadata, nil default: - return nil, nil, fmt.Errorf("unsupported file format: %v", format) + return nil, nil, fmt.Errorf("invalid file extension wanted: .cel or .celpolicy or .yaml found: %v", format) } } @@ -526,6 +558,10 @@ func (r *RawExpression) CreateAST(compiler Compiler) (*cel.Ast, map[string]any, if err != nil { return nil, nil, err } + format := InferFileFormat(r.Value) + if format != Unspecified { + return nil, nil, fmt.Errorf("invalid raw expression found file with extension: %v", format) + } ast, iss := e.Compile(r.Value) if iss.Err() != nil { return nil, nil, fmt.Errorf("e.Compile(%q) failed: %w", r.Value, iss.Err()) diff --git a/tools/compiler/compiler_test.go b/tools/compiler/compiler_test.go index d0c9ad0be..4bbad785e 100644 --- a/tools/compiler/compiler_test.go +++ b/tools/compiler/compiler_test.go @@ -19,12 +19,9 @@ import ( "testing" "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/env" - "github.com/google/cel-go/common/types" "github.com/google/cel-go/ext" "github.com/google/cel-go/policy" - "gopkg.in/yaml.v3" celpb "cel.dev/expr" configpb "cel.dev/expr/conformance" @@ -75,7 +72,7 @@ func TestEnvironmentFileCompareTextprotoAndYAML(t *testing.T) { for i, v := range protoConfig.Variables { for j, p := range v.TypeDesc.Params { if p.TypeName == "google.protobuf.Any" && - config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { + config.Variables[i].TypeDesc.Params[j].TypeName == "dyn" { p.TypeName = "dyn" } } @@ -186,10 +183,6 @@ func testEnvProto() *configpb.Environment { }, }, Extensions: []*configpb.Extension{ - { - Name: "optional", - Version: "latest", - }, { Name: "lists", Version: "latest", @@ -401,76 +394,6 @@ func TestFileExpressionCustomPolicyParser(t *testing.T) { }) } -func TestFileExpressionPolicyMetadataOptions(t *testing.T) { - t.Run("test file expression policy metadata options", func(t *testing.T) { - envOpt := EnvironmentFile("testdata/custom_policy_config.yaml") - parserOpt := policy.ParserOption(func(p *policy.Parser) (*policy.Parser, error) { - p.TagVisitor = customTagHandler{TagVisitor: policy.DefaultTagVisitor()} - return p, nil - }) - policyMetadataOpt := PolicyMetadataEnvOption(ParsePolicyVariables) - compilerOpts := []any{envOpt, parserOpt, policyMetadataOpt} - compiler, err := NewCompiler(compilerOpts...) - if err != nil { - t.Fatalf("NewCompiler() failed: %v", err) - } - policyFile := &FileExpression{ - Path: "testdata/custom_policy.celpolicy", - } - ast, _, err := policyFile.CreateAST(compiler) - if err != nil { - t.Fatalf("CreateAST() failed: %v", err) - } - if ast == nil { - t.Fatalf("CreateAST() returned nil ast") - } - }) -} - -func ParsePolicyVariables(metadata map[string]any) cel.EnvOption { - variables := []*decls.VariableDecl{} - for n, t := range metadata { - variables = append(variables, decls.NewVariable(n, parseCustomPolicyVariableType(t.(string)))) - } - return cel.VariableDecls(variables...) -} - -func parseCustomPolicyVariableType(t string) *types.Type { - switch t { - case "int": - return types.IntType - case "string": - return types.StringType - default: - return types.UnknownType - } -} - -type variableType struct { - VariableName string `yaml:"variable_name"` - VariableType string `yaml:"variable_type"` -} - -type customTagHandler struct { - policy.TagVisitor -} - -func (customTagHandler) PolicyTag(ctx policy.ParserContext, id int64, tagName string, node *yaml.Node, p *policy.Policy) { - switch tagName { - case "variable_types": - varList := []*variableType{} - if err := node.Decode(&varList); err != nil { - ctx.ReportErrorAtID(id, "invalid yaml variable_types node: %v, error: %w", node, err) - return - } - for _, v := range varList { - p.SetMetadata(v.VariableName, v.VariableType) - } - default: - ctx.ReportErrorAtID(id, "unsupported policy tag: %s", tagName) - } -} - func TestRawExpressionCreateAst(t *testing.T) { t.Run("test raw expression create ast", func(t *testing.T) { envOpt := EnvironmentFile("testdata/config.yaml") diff --git a/tools/compiler/testdata/config.yaml b/tools/compiler/testdata/config.yaml index 929427bc0..5ba153bbc 100644 --- a/tools/compiler/testdata/config.yaml +++ b/tools/compiler/testdata/config.yaml @@ -39,8 +39,6 @@ stdlib: return: type_name: "bool" extensions: - - name: "optional" - version: "latest" - name: "lists" version: "latest" - name: "sets" diff --git a/tools/go.mod b/tools/go.mod index 392efac8f..ea3a3dbf8 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -3,9 +3,10 @@ module github.com/google/cel-go/tools go 1.23.0 require ( - cel.dev/expr v0.22.1 + cel.dev/expr v0.23.1 github.com/google/cel-go v0.22.0 github.com/google/cel-go/policy v0.0.0-20250311174852-f5ea07b389a1 + github.com/google/go-cmp v0.6.0 google.golang.org/genproto/googleapis/api v0.0.0-20250311190419-81fb87f6b8bf google.golang.org/protobuf v1.36.5 gopkg.in/yaml.v3 v3.0.1 diff --git a/tools/go.sum b/tools/go.sum index b34becfc8..8e2a54ac3 100644 --- a/tools/go.sum +++ b/tools/go.sum @@ -1,5 +1,5 @@ -cel.dev/expr v0.22.1 h1:xoFEsNh972Yzey8N9TCPx2nDvMN7TMhQEzxLuj/iRrI= -cel.dev/expr v0.22.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= +cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/vendor/cel.dev/expr/MODULE.bazel b/vendor/cel.dev/expr/MODULE.bazel index c0a631316..85ac9ff61 100644 --- a/vendor/cel.dev/expr/MODULE.bazel +++ b/vendor/cel.dev/expr/MODULE.bazel @@ -8,7 +8,7 @@ bazel_dep( ) bazel_dep( name = "gazelle", - version = "0.36.0", + version = "0.39.1", repo_name = "bazel_gazelle", ) bazel_dep( @@ -35,11 +35,11 @@ bazel_dep( ) bazel_dep( name = "rules_cc", - version = "0.0.9", + version = "0.0.17", ) bazel_dep( name = "rules_go", - version = "0.50.1", + version = "0.53.0", repo_name = "io_bazel_rules_go", ) bazel_dep( @@ -48,7 +48,7 @@ bazel_dep( ) bazel_dep( name = "rules_proto", - version = "6.0.0", + version = "7.0.2", ) bazel_dep( name = "rules_python", @@ -63,7 +63,7 @@ python.toolchain( ) go_sdk = use_extension("@io_bazel_rules_go//go:extensions.bzl", "go_sdk") -go_sdk.download(version = "1.21.1") +go_sdk.download(version = "1.22.0") go_deps = use_extension("@bazel_gazelle//:extensions.bzl", "go_deps") go_deps.from_file(go_mod = "//:go.mod") diff --git a/vendor/modules.txt b/vendor/modules.txt index dfdf1bd13..a34dce8d0 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,5 +1,5 @@ -# cel.dev/expr v0.22.1 -## explicit; go 1.21.1 +# cel.dev/expr v0.23.1 +## explicit; go 1.22.0 cel.dev/expr # github.com/antlr4-go/antlr/v4 v4.13.0 ## explicit; go 1.20