Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions v1/ast/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,18 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error {
return newArgError(expr.Location, name, "too few arguments", getArgTypes(env, args), namedFargs)
}

pre := getArgTypes(env, args)

for i := range args {
// Check that pre-existing argument types are compatible with the expected types.
// Catching that case here avoids false negatives for builtins like sum([1, a]) where a is known to be a string.
if pre[i] != nil && !types.Nil(pre[i]) && !unifies(pre[i], fargs.Arg(i)) {
return newArgError(expr.Location, name, "invalid argument(s)", pre, namedFargs)
}

// unify1 infers types for untyped variables and checks resolved parts (constants, refs) inside partially-typed composites.
// The unifies pre-check above is skipped when the argument contains untyped variables,
// so unify1 is still needed to catch those errors.
if !unify1(env, args[i], fargs.Arg(i), false) {
post := make([]types.Type, len(args))
for i := range args {
Expand Down Expand Up @@ -549,6 +560,9 @@ func unify2Object(env *TypeEnv, a *Term, b *Term) bool {
return false
}

// unify1 walks into a term's structure (arrays, objects, sets, vars), checks
// compatibility against the expected type, and infers types for variables by
// assigning them in env. It uses unifies internally for leaf checks.
func unify1(env *TypeEnv, term *Term, tpe types.Type, union bool) bool {
switch v := term.Value.(type) {
case *Array:
Expand Down Expand Up @@ -839,6 +853,7 @@ func (rc *refChecker) checkRefLeaf(tpe types.Type, ref Ref, idx int) *Error {
return rc.checkRefLeaf(types.Values(tpe), ref, idx+1)
}

// unifies checks whether two types are compatible with each other.
func unifies(a, b types.Type) bool {

if a == nil || b == nil {
Expand Down Expand Up @@ -962,7 +977,14 @@ func unifiesObjects(a, b *types.Object) bool {

func unifiesObjectsStatic(a, b *types.Object) bool {
for _, k := range a.Keys() {
if !unifies(a.Select(k), b.Select(k)) {
tpeB := b.Select(k)
if tpeB == nil {
if a.DynamicValue() != nil {
continue
}
return false
}
if !unifies(a.Select(k), tpeB) {
return false
}
}
Expand Down Expand Up @@ -1174,7 +1196,7 @@ func removeDuplicate(list []Value) []Value {
func getArgTypes(env *TypeEnv, args []*Term) []types.Type {
pre := make([]types.Type, len(args))
for i := range args {
pre[i] = env.Get(args[i])
pre[i] = env.GetByValue(args[i].Value)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👍

}
return pre
}
Expand Down
92 changes: 92 additions & 0 deletions v1/ast/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2643,6 +2643,90 @@ allow if {
p if { [data.base.foo] }`,
expectedError: "policy.rego:3: rego_type_error: function data.base.foo used as reference, not called",
},
{
name: "wrong type used directly",
policy: `package p

s := sum([1, "foo"])`,
expectedError: "sum: invalid argument(s)",
},
{
name: "wrong type as reference",
policy: `package p

a := "foo"
s := sum([1, a])`,
expectedError: "sum: invalid argument(s)",
},
{
name: "wrong type as reference within rule",
policy: `package p

allow := s if {
a := "foo"
s := sum([1, a])
}`,
expectedError: "sum: invalid argument(s)",
},
{
name: "compare partial object",
policy: `package p

obj["a"] := input.a

obj["b"] := input.b

test_obj2 if {
{"a":"1"} == obj with input as {"a": "1"}
}`,
},
// this policy verifies this issue: https://github.com/open-policy-agent/opa/issues/6751
{
name: "compare two partial objects",
policy: `package p

obj1["a"] := input.a
obj1["b"] := input.b

obj2["x"] := input.x
obj2["y"] := input.y

test if {
obj1 == obj2
}`,
},
// this policy verifies this issue: https://github.com/open-policy-agent/opa/issues/5594
{
name: "compare two partial objects 2",
policy: `package test

obj["a"] := true

obj["b"] := "foo" if {
input.foo == "bar"
}

test_obj if {
obj == {
"a": true
} with input as {"foo": "baz"}
}`,
},
{
name: "sum with valid number args",
policy: `package p

a := 1
s := sum([1, a])`,
},
{
name: "sum with untyped var from input",
policy: `package p

allow := s if {
s := sum([1, input.a])
}`,
},
}

for _, tc := range tests {
Expand All @@ -2665,6 +2749,14 @@ p if { [data.base.foo] }`,
base := "package base\n" + body
compiler.Compile(map[string]*Module{"base": MustParseModuleWithOpts(base, pOpts), "policy.rego": module})

if tc.expectedError == "" {
if compiler.Failed() {
t.Fatalf("expected no error, but got %v", compiler.Errors.Error())
} else {
return
}
}

if !compiler.Failed() {
t.Fatal("expected error, got none")
}
Expand Down
Loading