Skip to content

Commit

Permalink
feat: implement compile time path propagation (#518)
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly authored Sep 24, 2024
1 parent 123686c commit 9b02801
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 84 deletions.
5 changes: 3 additions & 2 deletions pkg/apis/policy/v1alpha1/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/projection"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
)

// Any can be any type.
Expand All @@ -20,8 +21,8 @@ func NewAny(value any) Any {
}
}

func (t *Any) Compile(compilers compilers.Compilers) (projection.ScalarHandler, error) {
return projection.ParseScalar(t._value, compilers)
func (t *Any) Compile(path *field.Path, compilers compilers.Compilers) (projection.ScalarHandler, *field.Error) {
return projection.ParseScalar(path, t._value, compilers)
}

func (a *Any) MarshalJSON() ([]byte, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/policy/v1alpha1/assertion_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/kyverno/kyverno-json/pkg/core/assertion"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
)

// +k8s:deepcopy-gen=false
Expand All @@ -20,8 +21,8 @@ func NewAssertionTree(value any) AssertionTree {
}
}

func (t *AssertionTree) Compile(compilers compilers.Compilers) (assertion.Assertion, error) {
return assertion.Parse(t._tree, compilers)
func (t *AssertionTree) Compile(path *field.Path, compilers compilers.Compilers) (assertion.Assertion, *field.Error) {
return assertion.Parse(path, t._tree, compilers)
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
Expand Down
26 changes: 11 additions & 15 deletions pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,7 @@ type Assertion interface {
Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error)
}

func Parse(assertion any, compiler compilers.Compilers) (Assertion, error) {
out, err := parse(nil, assertion, compiler)
if err != nil {
return nil, err
}
return out, nil
}

func parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) {
func Parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(path, assertion, compiler)
Expand Down Expand Up @@ -68,7 +60,7 @@ func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
path := path.Index(i)
sub, err := parse(path, valueOf.Index(i).Interface(), compiler)
sub, err := Parse(path, valueOf.Index(i).Interface(), compiler)
if err != nil {
return nil, err
}
Expand All @@ -80,7 +72,7 @@ func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (
// mapNode is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
type mapNode map[any]struct {
projection.Projection
*projection.Projection
Assertion
}

Expand Down Expand Up @@ -154,13 +146,17 @@ func parseMap(path *field.Path, assertion any, compiler compilers.Compilers) (ma
key := iter.Key().Interface()
value := iter.Value().Interface()
path := path.Child(fmt.Sprint(key))
assertion, err := parse(path, value, compiler)
projection, err := projection.ParseMapKey(path, key, compiler)
if err != nil {
return nil, err
}
assertion, err := Parse(path, value, compiler)
if err != nil {
return nil, err
}
entry := assertions[key]
entry.Projection = projection
entry.Assertion = assertion
entry.Projection = projection.ParseMapKey(key, compiler)
assertions[key] = entry
}
return assertions, nil
Expand All @@ -184,9 +180,9 @@ func (node scalarNode) Assert(path *field.Path, value any, bindings binding.Bind
}

func parseScalar(path *field.Path, in any, compiler compilers.Compilers) (scalarNode, *field.Error) {
proj, err := projection.ParseScalar(in, compiler)
proj, err := projection.ParseScalar(path, in, compiler)
if err != nil {
return nil, field.InternalError(path, err)
return nil, err
}
return proj, nil
}
Expand Down
44 changes: 26 additions & 18 deletions pkg/core/assertion/assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,51 @@ func TestAssert(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler)
assert.NoError(t, err)
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
parsed, err := Parse(nil, tt.assertion, compiler)
assert.Nil(t, err)
{
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
}
assert.Equal(t, tt.want, got)
})
}
}

func TestParse(t *testing.T) {
tests := []struct {
name string
assertion any
want field.ErrorList
wantErr bool
name string
assertion any
wantAssertion bool
wantErr *field.Error
}{{
name: "bad scalar",
assertion: map[string]any{
"foo": map[string]any{
"bar": "~.(`42`)",
},
},
wantErr: true,
wantAssertion: false,
wantErr: &field.Error{
Type: field.ErrorTypeInvalid,
Field: "foo.bar",
BadValue: "~.(`42`)",
Detail: "foreach is not supported in scalar projections",
},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
parsed, err := Parse(nil, tt.assertion, compiler)
assert.Equal(t, tt.wantErr, err)
if tt.wantAssertion {
assert.NotNil(t, parsed)
} else {
assert.Nil(t, parsed)
}
})
}
Expand Down
35 changes: 15 additions & 20 deletions pkg/core/projection/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package projection
import (
"errors"
"reflect"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/expression"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
"k8s.io/apimachinery/pkg/util/validation/field"
)

type (
Expand All @@ -27,7 +27,8 @@ type Projection struct {
Handler MapKeyHandler
}

func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
func ParseMapKey(path *field.Path, in any, compiler compilers.Compilers) (*Projection, *field.Error) {
var projection Projection
switch typed := in.(type) {
case string:
// 1. if we have a string, parse the expression
Expand All @@ -38,14 +39,11 @@ func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
projection.Binding = expr.Binding
// 3. compute the projection func
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
program, err := compiler.Compile(expr.Statement)
if err != nil {
return nil, field.Invalid(path, expr.Statement, err.Error())
}
projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) {
program, err := compile()
if err != nil {
return nil, false, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, false, err
Expand Down Expand Up @@ -83,28 +81,25 @@ func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
return nil, false, errors.New("projection not recognized")
}
}
return
return &projection, nil
}

func ParseScalar(in any, compiler compilers.Compilers) (ScalarHandler, error) {
func ParseScalar(path *field.Path, in any, compiler compilers.Compilers) (ScalarHandler, *field.Error) {
switch typed := in.(type) {
case string:
expr := expression.Parse(typed)
if expr.Foreach {
return nil, errors.New("foreach is not supported in scalar projections")
return nil, field.Invalid(path, typed, "foreach is not supported in scalar projections")
}
if expr.Binding != "" {
return nil, errors.New("binding is not supported in scalar projections")
return nil, field.Invalid(path, typed, "binding is not supported in scalar projections")
}
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
program, err := compiler.Compile(expr.Statement)
if err != nil {
return nil, field.Invalid(path, expr.Statement, err.Error())
}
return func(value any, bindings binding.Bindings) (any, error) {
program, err := compile()
if err != nil {
return nil, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, err
Expand Down
19 changes: 11 additions & 8 deletions pkg/core/projection/projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,18 @@ func TestParseMap(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
proj := ParseMapKey(tt.key, compiler)
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
proj, err := ParseMapKey(nil, tt.key, compiler)
assert.Nil(t, err)
{
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
}
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
})
}
}
Loading

0 comments on commit 9b02801

Please sign in to comment.