diff --git a/common/env/env.go b/common/env/env.go index aa7b066c5..07294c696 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -115,7 +115,7 @@ func (c *Config) AddVariableDecls(vars ...*decls.VariableDecl) *Config { if v == nil { continue } - convVars[i] = NewVariable(v.Name(), serializeTypeDesc(v.Type())) + convVars[i] = NewVariable(v.Name(), SerializeTypeDesc(v.Type())) } return c.AddVariables(convVars...) } @@ -146,9 +146,9 @@ func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config { overloadID := o.ID() args := make([]*TypeDesc, 0, len(o.ArgTypes())) for _, a := range o.ArgTypes() { - args = append(args, serializeTypeDesc(a)) + args = append(args, SerializeTypeDesc(a)) } - ret := serializeTypeDesc(o.ResultType()) + ret := SerializeTypeDesc(o.ResultType()) if o.IsMemberFunction() { overloads = append(overloads, NewMemberOverload(overloadID, args[0], args[1:], ret)) } else { @@ -836,7 +836,8 @@ func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) { } } -func serializeTypeDesc(t *types.Type) *TypeDesc { +// SerializeTypeDesc converts *types.Type to a serialized format TypeDesc +func SerializeTypeDesc(t *types.Type) *TypeDesc { typeName := t.TypeName() if t.Kind() == types.TypeParamKind { return NewTypeParam(typeName) @@ -848,7 +849,7 @@ func serializeTypeDesc(t *types.Type) *TypeDesc { } var params []*TypeDesc for _, p := range t.Parameters() { - params = append(params, serializeTypeDesc(p)) + params = append(params, SerializeTypeDesc(p)) } return NewTypeDesc(typeName, params...) } diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index b764fa1f5..62863c17a 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -10,6 +10,7 @@ go_library( "bindings.go", "comprehensions.go", "encoders.go", + "extension_option_factory.go", "formatting.go", "guards.go", "lists.go", @@ -26,6 +27,7 @@ go_library( "//checker:go_default_library", "//common/ast:go_default_library", "//common/decls:go_default_library", + "//common/env:go_default_library", "//common/overloads:go_default_library", "//common/operators:go_default_library", "//common/types:go_default_library", @@ -48,7 +50,8 @@ go_test( srcs = [ "bindings_test.go", "comprehensions_test.go", - "encoders_test.go", + "encoders_test.go", + "extension_option_factory_test.go", "lists_test.go", "math_test.go", "native_test.go", @@ -62,6 +65,7 @@ go_test( deps = [ "//cel:go_default_library", "//checker:go_default_library", + "//common/env:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", diff --git a/ext/extension_option_factory.go b/ext/extension_option_factory.go new file mode 100644 index 000000000..4906227a5 --- /dev/null +++ b/ext/extension_option_factory.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/env" +) + +// ExtensionOptionFactory converts an ExtensionConfig value to a CEL environment option. +func ExtensionOptionFactory(configElement any) (cel.EnvOption, bool) { + ext, isExtension := configElement.(*env.Extension) + if !isExtension { + return nil, false + } + fac, found := extFactories[ext.Name] + if !found { + return nil, false + } + // If the version is 'latest', set the version value to the max uint. + ver, err := ext.VersionNumber() + if err != nil { + return func(*cel.Env) (*cel.Env, error) { + return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) + }, true + } + return fac(ver), true +} + +// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. +type extensionFactory func(uint32) cel.EnvOption + +var extFactories = map[string]extensionFactory{ + "bindings": func(version uint32) cel.EnvOption { + return Bindings(BindingsVersion(version)) + }, + "encoders": func(version uint32) cel.EnvOption { + return Encoders(EncodersVersion(version)) + }, + "lists": func(version uint32) cel.EnvOption { + return Lists(ListsVersion(version)) + }, + "math": func(version uint32) cel.EnvOption { + return Math(MathVersion(version)) + }, + "protos": func(version uint32) cel.EnvOption { + return Protos(ProtosVersion(version)) + }, + "sets": func(version uint32) cel.EnvOption { + return Sets(SetsVersion(version)) + }, + "strings": func(version uint32) cel.EnvOption { + return Strings(StringsVersion(version)) + }, + "two-var-comprehensions": func(version uint32) cel.EnvOption { + return TwoVarComprehensions(TwoVarComprehensionsVersion(version)) + }, +} diff --git a/ext/extension_option_factory_test.go b/ext/extension_option_factory_test.go new file mode 100644 index 000000000..f721bb6bf --- /dev/null +++ b/ext/extension_option_factory_test.go @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ext + +import ( + "fmt" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/env" +) + +func TestExtensionOptionFactoryInvalidExtension(t *testing.T) { + invalidExtension := "invalid extension" + _, validExtension := ExtensionOptionFactory(invalidExtension) + if validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned valid extension for invalid input", invalidExtension) + } +} + +func TestExtensionOptionFactoryInvalidExtensionName(t *testing.T) { + e := &env.Extension{Name: "invalid extension name"} + _, validExtension := ExtensionOptionFactory(e) + if validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned valid extension for invalid extension name", e.Name) + } +} + +func TestExtensionOptionFactoryInvalidExtensionVersion(t *testing.T) { + e := &env.Extension{Name: "bindings", Version: "invalid version"} + opt, validExtension := ExtensionOptionFactory(e) + if !validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } + _, err := cel.NewCustomEnv(opt) + if err == nil || err.Error() != fmt.Sprintf("invalid extension version: %s - %s", e.Name, e.Version) { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension version", e.Name) + } +} + +func TestExtensionOptionFactoryValidBindingsExtension(t *testing.T) { + e := &env.Extension{Name: "bindings", Version: "latest"} + opt, validExtension := ExtensionOptionFactory(e) + if !validExtension { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } + en, err := cel.NewCustomEnv(opt) + if err != nil { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } + cfg, err := en.ToConfig("test config") + if len(cfg.Extensions) != 1 || cfg.Extensions[0].Name != "cel.lib.ext.cel.bindings" || cfg.Extensions[0].Version != "latest" { + t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name) + } +} diff --git a/policy/config.go b/policy/config.go index 12fbe44a7..02243922b 100644 --- a/policy/config.go +++ b/policy/config.go @@ -15,8 +15,6 @@ package policy import ( - "fmt" - "github.com/google/cel-go/cel" "github.com/google/cel-go/common/env" "github.com/google/cel-go/ext" @@ -28,55 +26,5 @@ import ( // a set of configuration ConfigOptionFactory values to handle extensions and other config features // which may be defined outside of the `cel` package. func FromConfig(config *env.Config) cel.EnvOption { - return cel.FromConfig(config, extensionOptionFactory) -} - -// extensionOptionFactory converts an ExtensionConfig value to a CEL environment option. -func extensionOptionFactory(configElement any) (cel.EnvOption, bool) { - ext, isExtension := configElement.(*env.Extension) - if !isExtension { - return nil, false - } - fac, found := extFactories[ext.Name] - if !found { - return nil, false - } - // If the version is 'latest', set the version value to the max uint. - ver, err := ext.VersionNumber() - if err != nil { - return func(*cel.Env) (*cel.Env, error) { - return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version) - }, true - } - return fac(ver), true -} - -// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension. -type extensionFactory func(uint32) cel.EnvOption - -var extFactories = map[string]extensionFactory{ - "bindings": func(version uint32) cel.EnvOption { - return ext.Bindings(ext.BindingsVersion(version)) - }, - "encoders": func(version uint32) cel.EnvOption { - return ext.Encoders(ext.EncodersVersion(version)) - }, - "lists": func(version uint32) cel.EnvOption { - return ext.Lists(ext.ListsVersion(version)) - }, - "math": func(version uint32) cel.EnvOption { - return ext.Math(ext.MathVersion(version)) - }, - "protos": func(version uint32) cel.EnvOption { - return ext.Protos(ext.ProtosVersion(version)) - }, - "sets": func(version uint32) cel.EnvOption { - return ext.Sets(ext.SetsVersion(version)) - }, - "strings": func(version uint32) cel.EnvOption { - return ext.Strings(ext.StringsVersion(version)) - }, - "two-var-comprehensions": func(version uint32) cel.EnvOption { - return ext.TwoVarComprehensions(ext.TwoVarComprehensionsVersion(version)) - }, + return cel.FromConfig(config, ext.ExtensionOptionFactory) }