Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cel
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -239,6 +240,43 @@ func TestAbbrevsDisambiguation(t *testing.T) {
}
}

func TestConvertToNativeJSONStructure(t *testing.T) {
env, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, issues := env.Compile(`{
"parts": [{"kind": "text"}]
}`)
if issues != nil && issues.Err() != nil {
t.Fatal(issues.Err())
}

prg, err := env.Program(ast)
if err != nil {
t.Fatal(err)
}

result, _, err := prg.Eval(map[string]any{})
if err != nil {
t.Fatal(err)
}

native, err := result.ConvertToNative(types.JSONValueType)
if err != nil {
t.Fatal(err)
}

jsonBytes, err := json.Marshal(native)
if err != nil {
t.Fatalf("json.Marshal failed: %v", err)
}
want := `{"parts":[{"kind":"text"}]}`
if string(jsonBytes) != want {
t.Errorf("json.Marshal() failed, got : %s, wanted ", jsonBytes)
}
}

func TestCustomEnvError(t *testing.T) {
env, err := NewCustomEnv(StdLib(), StdLib())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion common/types/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (b Bool) ConvertToNative(typeDesc reflect.Type) (any, error) {
case boolWrapperType:
// Convert the bool to a wrapperspb.BoolValue.
return wrapperspb.Bool(bool(b)), nil
case jsonValueType:
case JSONValueType:
// Return the bool as a new structpb.Value.
return structpb.NewBoolValue(bool(b)), nil
default:
Expand Down
2 changes: 1 addition & 1 deletion common/types/bool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestBoolConvertToNative_Error(t *testing.T) {
}

func TestBoolConvertToNative_Json(t *testing.T) {
val, err := True.ConvertToNative(jsonValueType)
val, err := True.ConvertToNative(JSONValueType)
pbVal := &structpb.Value{Kind: &structpb.Value_BoolValue{BoolValue: true}}
if err != nil {
t.Error(err)
Expand Down
2 changes: 1 addition & 1 deletion common/types/bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) {
case byteWrapperType:
// Convert the bytes to a wrapperspb.BytesValue.
return wrapperspb.Bytes([]byte(b)), nil
case jsonValueType:
case JSONValueType:
// CEL follows the proto3 to JSON conversion by encoding bytes to a string via base64.
// The encoding below matches the golang 'encoding/json' behavior during marshaling,
// which uses base64.StdEncoding.
Expand Down
2 changes: 1 addition & 1 deletion common/types/bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestBytesConvertToNative_Error(t *testing.T) {
}

func TestBytesConvertToNative_Json(t *testing.T) {
val, err := Bytes("123").ConvertToNative(jsonValueType)
val, err := Bytes("123").ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
}
Expand Down
2 changes: 1 addition & 1 deletion common/types/double.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (d Double) ConvertToNative(typeDesc reflect.Type) (any, error) {
case floatWrapperType:
// Convert to a wrapperspb.FloatValue (with truncation).
return wrapperspb.Float(float32(d)), nil
case jsonValueType:
case JSONValueType:
// Note, there are special cases for proto3 to json conversion that
// expect the floating point value to be converted to a NaN,
// Infinity, or -Infinity string values, but the jsonpb string
Expand Down
8 changes: 4 additions & 4 deletions common/types/double_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,15 @@ func TestDoubleConvertToNative_Float64(t *testing.T) {
}

func TestDoubleConvertToNative_Json(t *testing.T) {
val, err := Double(-1.4).ConvertToNative(jsonValueType)
val, err := Double(-1.4).ConvertToNative(JSONValueType)
pbVal := structpb.NewNumberValue(-1.4)
if err != nil {
t.Error(err)
} else if !proto.Equal(val.(proto.Message), pbVal) {
t.Errorf("Got '%v', expected -1.4", val)
}

val, err = Double(math.NaN()).ConvertToNative(jsonValueType)
val, err = Double(math.NaN()).ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
} else {
Expand All @@ -189,14 +189,14 @@ func TestDoubleConvertToNative_Json(t *testing.T) {
}
}

val, err = Double(math.Inf(-1)).ConvertToNative(jsonValueType)
val, err = Double(math.Inf(-1)).ConvertToNative(JSONValueType)
pbVal = structpb.NewNumberValue(math.Inf(-1))
if err != nil {
t.Error(err)
} else if !proto.Equal(val.(proto.Message), pbVal) {
t.Errorf("Got '%v', expected -Infinity", val)
}
val, err = Double(math.Inf(0)).ConvertToNative(jsonValueType)
val, err = Double(math.Inf(0)).ConvertToNative(JSONValueType)
pbVal = structpb.NewNumberValue(math.Inf(0))
if err != nil {
t.Error(err)
Expand Down
2 changes: 1 addition & 1 deletion common/types/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (d Duration) ConvertToNative(typeDesc reflect.Type) (any, error) {
case durationValueType:
// Unwrap the CEL value to its underlying proto value.
return dpb.New(d.Duration), nil
case jsonValueType:
case JSONValueType:
// CEL follows the proto3 to JSON conversion.
// Note, using jsonpb would wrap the result in extra double quotes.
v := d.ConvertToType(StringType)
Expand Down
4 changes: 2 additions & 2 deletions common/types/duration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestDurationConvertToNative_Any(t *testing.T) {
}

func TestDurationConvertToNative_Error(t *testing.T) {
val, err := Duration{Duration: duration(7506, 1000)}.ConvertToNative(jsonValueType)
val, err := Duration{Duration: duration(7506, 1000)}.ConvertToNative(JSONValueType)
if err != nil {
t.Errorf("Got error: '%v', expected value", err)
}
Expand All @@ -178,7 +178,7 @@ func TestDurationConvertToNative_Error(t *testing.T) {
}

func TestDurationConvertToNative_Json(t *testing.T) {
val, err := Duration{Duration: duration(7506, 1000)}.ConvertToNative(jsonValueType)
val, err := Duration{Duration: duration(7506, 1000)}.ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
}
Expand Down
2 changes: 1 addition & 1 deletion common/types/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
case int64WrapperType:
// Convert the value to a wrapperspb.Int64Value.
return wrapperspb.Int64(int64(i)), nil
case jsonValueType:
case JSONValueType:
// The proto-to-JSON conversion rules would convert all 64-bit integer values to JSON
// decimal strings. Because CEL ints might come from the automatic widening of 32-bit
// values in protos, the JSON type is chosen dynamically based on the value.
Expand Down
6 changes: 3 additions & 3 deletions common/types/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestIntConvertToNative_Any(t *testing.T) {
}

func TestIntConvertToNative_Error(t *testing.T) {
val, err := Int(1).ConvertToNative(jsonStructType)
val, err := Int(1).ConvertToNative(JSONStructType)
if err == nil {
t.Errorf("Got '%v', expected error", val)
}
Expand Down Expand Up @@ -223,7 +223,7 @@ func TestIntConvertToNative_Int64(t *testing.T) {

func TestIntConvertToNative_Json(t *testing.T) {
// Value can be represented accurately as a JSON number.
val, err := Int(maxIntJSON).ConvertToNative(jsonValueType)
val, err := Int(maxIntJSON).ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
} else if !proto.Equal(val.(proto.Message),
Expand All @@ -232,7 +232,7 @@ func TestIntConvertToNative_Json(t *testing.T) {
}

// Value converts to a JSON decimal string.
val, err = Int(maxIntJSON + 1).ConvertToNative(jsonValueType)
val, err = Int(maxIntJSON + 1).ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
} else if !proto.Equal(val.(proto.Message), structpb.NewStringValue("9007199254740992")) {
Expand Down
8 changes: 4 additions & 4 deletions common/types/json_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestJsonListValueAdd(t *testing.T) {
structpb.NewNumberValue(2),
structpb.NewNumberValue(3)}})
list := listA.Add(listB).(traits.Lister)
nativeVal, err := list.ConvertToNative(jsonListValueType)
nativeVal, err := list.ConvertToNative(JSONListType)
if err != nil {
t.Error(err)
}
Expand All @@ -50,7 +50,7 @@ func TestJsonListValueAdd(t *testing.T) {
}
listC := NewStringList(reg, []string{"goodbye", "world"})
list = list.Add(listC).(traits.Lister)
nativeVal, err = list.ConvertToNative(jsonListValueType)
nativeVal, err = list.ConvertToNative(JSONListType)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -101,15 +101,15 @@ func TestJsonListValueConvertToNative_Json(t *testing.T) {
list := NewJSONList(newTestRegistry(t), &structpb.ListValue{Values: []*structpb.Value{
structpb.NewStringValue("hello"),
structpb.NewNumberValue(1)}})
listVal, err := list.ConvertToNative(jsonListValueType)
listVal, err := list.ConvertToNative(JSONListType)
if err != nil {
t.Error(err)
}
if listVal != list.Value().(proto.Message) {
t.Error("List did not convert to its underlying representation.")
}

val, err := list.ConvertToNative(jsonValueType)
val, err := list.ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
}
Expand Down
4 changes: 2 additions & 2 deletions common/types/json_struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestJsonStructConvertToNative_Json(t *testing.T) {
"first": structpb.NewStringValue("hello"),
"second": structpb.NewNumberValue(1)}}
mapVal := NewJSONStruct(newTestRegistry(t), structVal)
val, err := mapVal.ConvertToNative(jsonValueType)
val, err := mapVal.ConvertToNative(JSONValueType)
if err != nil {
t.Error(err)
}
Expand All @@ -50,7 +50,7 @@ func TestJsonStructConvertToNative_Json(t *testing.T) {
t.Errorf("Got '%v', expected '%v'", val, structVal)
}

strVal, err := mapVal.ConvertToNative(jsonStructType)
strVal, err := mapVal.ConvertToNative(JSONStructType)
if err != nil {
t.Error(err)
}
Expand Down
9 changes: 5 additions & 4 deletions common/types/json_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ import (

// JSON type constants representing the reflected types of protobuf JSON values.
var (
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonListValueType = reflect.TypeOf(&structpb.ListValue{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
// JSONValueType describes the protobuf native type for a JSON value.
JSONValueType = reflect.TypeFor[*structpb.Value]()
JSONListType = reflect.TypeFor[*structpb.ListValue]()
JSONStructType = reflect.TypeFor[*structpb.Struct]()
JSONNullType = reflect.TypeFor[structpb.NullValue]()
)
9 changes: 6 additions & 3 deletions common/types/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ func (l *baseList) Contains(elem ref.Val) ref.Val {

// ConvertToNative implements the ref.Val interface method.
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
if typeDesc == reflect.TypeFor[any]() {
typeDesc = reflect.TypeFor[[]any]()
}
// If the underlying list value is assignable to the reflected type return it.
if reflect.TypeOf(l.value).AssignableTo(typeDesc) {
return l.value, nil
Expand All @@ -164,19 +167,19 @@ func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// Attempt to convert the list to a set of well known protobuf types.
switch typeDesc {
case anyValueType:
json, err := l.ConvertToNative(jsonListValueType)
json, err := l.ConvertToNative(JSONListType)
if err != nil {
return nil, err
}
return anypb.New(json.(proto.Message))
case jsonValueType, jsonListValueType:
case JSONValueType, JSONListType:
jsonValues, err :=
l.ConvertToNative(reflect.TypeOf([]*structpb.Value{}))
if err != nil {
return nil, err
}
jsonList := &structpb.ListValue{Values: jsonValues.([]*structpb.Value)}
if typeDesc == jsonListValueType {
if typeDesc == JSONListType {
return jsonList, nil
}
return structpb.NewListValue(jsonList), nil
Expand Down
14 changes: 7 additions & 7 deletions common/types/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func TestBaseListConvertToNative_Any(t *testing.T) {

func TestBaseListConvertToNative_Json(t *testing.T) {
list := NewDynamicList(newTestRegistry(t), []float64{1.0, 2.0})
val, err := list.ConvertToNative(jsonListValueType)
val, err := list.ConvertToNative(JSONListType)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -359,7 +359,7 @@ func TestConcatListConvertToNative_Json(t *testing.T) {
listA := NewDynamicList(reg, []float32{1.0, 2.0})
listB := NewDynamicList(reg, []string{"3"})
list := listA.Add(listB)
jsonVal, err := list.ConvertToNative(jsonValueType)
jsonVal, err := list.ConvertToNative(JSONValueType)
if err != nil {
t.Fatalf("Got error '%v', expected value", err)
}
Expand All @@ -379,7 +379,7 @@ func TestConcatListConvertToNative_Json(t *testing.T) {
// Test proto3 to JSON conversion.
listC := NewDynamicList(reg, []*dpb.Duration{{Seconds: 100}})
listConcat := listA.Add(listC)
jsonVal, err = listConcat.ConvertToNative(jsonValueType)
jsonVal, err = listConcat.ConvertToNative(JSONValueType)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -607,7 +607,7 @@ func TestStringListConvertToNative_ListInterface(t *testing.T) {
func TestStringListConvertToNative_Error(t *testing.T) {
reg := newTestRegistry(t)
list := NewStringList(reg, []string{"h", "e", "l", "p"})
_, err := list.ConvertToNative(jsonStructType)
_, err := list.ConvertToNative(JSONStructType)
if err == nil {
t.Error("Conversion of list to unsupported type did not error.")
}
Expand All @@ -616,7 +616,7 @@ func TestStringListConvertToNative_Error(t *testing.T) {
func TestStringListConvertToNative_Json(t *testing.T) {
reg := newTestRegistry(t)
list := NewStringList(reg, []string{"h", "e", "l", "p"})
jsonVal, err := list.ConvertToNative(jsonValueType)
jsonVal, err := list.ConvertToNative(JSONValueType)
if err != nil {
t.Errorf("Got '%v', expected '%v'", err, jsonVal)
}
Expand All @@ -634,7 +634,7 @@ func TestStringListConvertToNative_Json(t *testing.T) {
t.Errorf("got json '%v', expected %v", jsonTxt, outList)
}

jsonList, err := list.ConvertToNative(jsonListValueType)
jsonList, err := list.ConvertToNative(JSONListType)
if err != nil {
t.Errorf("Got '%v', expected '%v'", err, jsonList)
}
Expand Down Expand Up @@ -681,7 +681,7 @@ func TestValueListAdd(t *testing.T) {
func TestValueListConvertToNative_Json(t *testing.T) {
reg := newTestRegistry(t)
list := NewRefValList(reg, []ref.Val{String("hello"), String("world")})
jsonVal, err := list.ConvertToNative(jsonListValueType)
jsonVal, err := list.ConvertToNative(JSONListType)
if err != nil {
t.Errorf("Got '%v', expected '%v'", err, jsonVal)
}
Expand Down
Loading