diff --git a/v1/ast/check.go b/v1/ast/check.go index 6e4d8ddd74..0aa7dfab1b 100644 --- a/v1/ast/check.go +++ b/v1/ast/check.go @@ -419,7 +419,18 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error { return newArgError(expr.Location, name, "too few arguments", getArgTypes(env, args), namedFargs) } + pre := getArgTypes(env, args) + for i := range args { + // Check that pre-existing argument types are compatible with the expected types. + // Catching that case here avoids false negatives for builtins like sum([1, a]) where a is known to be a string. + if pre[i] != nil && !types.Nil(pre[i]) && !unifies(pre[i], fargs.Arg(i)) { + return newArgError(expr.Location, name, "invalid argument(s)", pre, namedFargs) + } + + // unify1 infers types for untyped variables and checks resolved parts (constants, refs) inside partially-typed composites. + // The unifies pre-check above is skipped when the argument contains untyped variables, + // so unify1 is still needed to catch those errors. if !unify1(env, args[i], fargs.Arg(i), false) { post := make([]types.Type, len(args)) for i := range args { @@ -549,6 +560,9 @@ func unify2Object(env *TypeEnv, a *Term, b *Term) bool { return false } +// unify1 walks into a term's structure (arrays, objects, sets, vars), checks +// compatibility against the expected type, and infers types for variables by +// assigning them in env. It uses unifies internally for leaf checks. func unify1(env *TypeEnv, term *Term, tpe types.Type, union bool) bool { switch v := term.Value.(type) { case *Array: @@ -839,6 +853,7 @@ func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error { return rc.checkRefLeaf(types.Values(tpe), ref, idx+1) } +// unifies checks whether two types are compatible with each other. func unifies(a, b types.Type) bool { if a == nil || b == nil { @@ -962,7 +977,14 @@ func unifiesObjects(a, b *types.Object) bool { func unifiesObjectsStatic(a, b *types.Object) bool { for _, k := range a.Keys() { - if !unifies(a.Select(k), b.Select(k)) { + tpeB := b.Select(k) + if tpeB == nil { + if a.DynamicValue() != nil { + continue + } + return false + } + if !unifies(a.Select(k), tpeB) { return false } } @@ -1174,7 +1196,7 @@ func removeDuplicate(list []Value) []Value { func getArgTypes(env *TypeEnv, args []*Term) []types.Type { pre := make([]types.Type, len(args)) for i := range args { - pre[i] = env.Get(args[i]) + pre[i] = env.GetByValue(args[i].Value) } return pre } diff --git a/v1/ast/check_test.go b/v1/ast/check_test.go index a63fe7712f..deaaf60ba3 100644 --- a/v1/ast/check_test.go +++ b/v1/ast/check_test.go @@ -2643,6 +2643,90 @@ allow if { p if { [data.base.foo] }`, expectedError: "policy.rego:3: rego_type_error: function data.base.foo used as reference, not called", }, + { + name: "wrong type used directly", + policy: `package p + +s := sum([1, "foo"])`, + expectedError: "sum: invalid argument(s)", + }, + { + name: "wrong type as reference", + policy: `package p + +a := "foo" +s := sum([1, a])`, + expectedError: "sum: invalid argument(s)", + }, + { + name: "wrong type as reference within rule", + policy: `package p + +allow := s if { + a := "foo" + s := sum([1, a]) +}`, + expectedError: "sum: invalid argument(s)", + }, + { + name: "compare partial object", + policy: `package p + +obj["a"] := input.a + +obj["b"] := input.b + +test_obj2 if { + {"a":"1"} == obj with input as {"a": "1"} +}`, + }, + // this policy verifies this issue: https://github.com/open-policy-agent/opa/issues/6751 + { + name: "compare two partial objects", + policy: `package p + +obj1["a"] := input.a +obj1["b"] := input.b + +obj2["x"] := input.x +obj2["y"] := input.y + +test if { + obj1 == obj2 +}`, + }, + // this policy verifies this issue: https://github.com/open-policy-agent/opa/issues/5594 + { + name: "compare two partial objects 2", + policy: `package test + +obj["a"] := true + +obj["b"] := "foo" if { + input.foo == "bar" +} + +test_obj if { + obj == { + "a": true + } with input as {"foo": "baz"} +}`, + }, + { + name: "sum with valid number args", + policy: `package p + +a := 1 +s := sum([1, a])`, + }, + { + name: "sum with untyped var from input", + policy: `package p + +allow := s if { + s := sum([1, input.a]) +}`, + }, } for _, tc := range tests { @@ -2665,6 +2749,14 @@ p if { [data.base.foo] }`, base := "package base\n" + body compiler.Compile(map[string]*Module{"base": MustParseModuleWithOpts(base, pOpts), "policy.rego": module}) + if tc.expectedError == "" { + if compiler.Failed() { + t.Fatalf("expected no error, but got %v", compiler.Errors.Error()) + } else { + return + } + } + if !compiler.Failed() { t.Fatal("expected error, got none") }