Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement compile time path propagation #518

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pkg/apis/policy/v1alpha1/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/projection"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
)

// Any can be any type.
Expand All @@ -20,8 +21,8 @@ func NewAny(value any) Any {
}
}

func (t *Any) Compile(compilers compilers.Compilers) (projection.ScalarHandler, error) {
return projection.ParseScalar(t._value, compilers)
func (t *Any) Compile(path *field.Path, compilers compilers.Compilers) (projection.ScalarHandler, *field.Error) {
return projection.ParseScalar(path, t._value, compilers)
}

func (a *Any) MarshalJSON() ([]byte, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/policy/v1alpha1/assertion_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/kyverno/kyverno-json/pkg/core/assertion"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
)

// +k8s:deepcopy-gen=false
Expand All @@ -20,8 +21,8 @@ func NewAssertionTree(value any) AssertionTree {
}
}

func (t *AssertionTree) Compile(compilers compilers.Compilers) (assertion.Assertion, error) {
return assertion.Parse(t._tree, compilers)
func (t *AssertionTree) Compile(path *field.Path, compilers compilers.Compilers) (assertion.Assertion, *field.Error) {
return assertion.Parse(path, t._tree, compilers)
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
Expand Down
26 changes: 11 additions & 15 deletions pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,7 @@ type Assertion interface {
Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error)
}

func Parse(assertion any, compiler compilers.Compilers) (Assertion, error) {
out, err := parse(nil, assertion, compiler)
if err != nil {
return nil, err
}
return out, nil
}

func parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) {
func Parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(path, assertion, compiler)
Expand Down Expand Up @@ -68,7 +60,7 @@ func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
path := path.Index(i)
sub, err := parse(path, valueOf.Index(i).Interface(), compiler)
sub, err := Parse(path, valueOf.Index(i).Interface(), compiler)
if err != nil {
return nil, err
}
Expand All @@ -80,7 +72,7 @@ func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (
// mapNode is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
type mapNode map[any]struct {
projection.Projection
*projection.Projection
Assertion
}

Expand Down Expand Up @@ -154,13 +146,17 @@ func parseMap(path *field.Path, assertion any, compiler compilers.Compilers) (ma
key := iter.Key().Interface()
value := iter.Value().Interface()
path := path.Child(fmt.Sprint(key))
assertion, err := parse(path, value, compiler)
projection, err := projection.ParseMapKey(path, key, compiler)
if err != nil {
return nil, err
}
assertion, err := Parse(path, value, compiler)
if err != nil {
return nil, err
}
entry := assertions[key]
entry.Projection = projection
entry.Assertion = assertion
entry.Projection = projection.ParseMapKey(key, compiler)
assertions[key] = entry
}
return assertions, nil
Expand All @@ -184,9 +180,9 @@ func (node scalarNode) Assert(path *field.Path, value any, bindings binding.Bind
}

func parseScalar(path *field.Path, in any, compiler compilers.Compilers) (scalarNode, *field.Error) {
proj, err := projection.ParseScalar(in, compiler)
proj, err := projection.ParseScalar(path, in, compiler)
if err != nil {
return nil, field.InternalError(path, err)
return nil, err
}
return proj, nil
}
Expand Down
44 changes: 26 additions & 18 deletions pkg/core/assertion/assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,51 @@ func TestAssert(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler)
assert.NoError(t, err)
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
parsed, err := Parse(nil, tt.assertion, compiler)
assert.Nil(t, err)
{
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
}
assert.Equal(t, tt.want, got)
})
}
}

func TestParse(t *testing.T) {
tests := []struct {
name string
assertion any
want field.ErrorList
wantErr bool
name string
assertion any
wantAssertion bool
wantErr *field.Error
}{{
name: "bad scalar",
assertion: map[string]any{
"foo": map[string]any{
"bar": "~.(`42`)",
},
},
wantErr: true,
wantAssertion: false,
wantErr: &field.Error{
Type: field.ErrorTypeInvalid,
Field: "foo.bar",
BadValue: "~.(`42`)",
Detail: "foreach is not supported in scalar projections",
},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
parsed, err := Parse(nil, tt.assertion, compiler)
assert.Equal(t, tt.wantErr, err)
if tt.wantAssertion {
assert.NotNil(t, parsed)
} else {
assert.Nil(t, parsed)
}
})
}
Expand Down
35 changes: 15 additions & 20 deletions pkg/core/projection/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package projection
import (
"errors"
"reflect"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/expression"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
"k8s.io/apimachinery/pkg/util/validation/field"
)

type (
Expand All @@ -27,7 +27,8 @@ type Projection struct {
Handler MapKeyHandler
}

func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
func ParseMapKey(path *field.Path, in any, compiler compilers.Compilers) (*Projection, *field.Error) {
var projection Projection
switch typed := in.(type) {
case string:
// 1. if we have a string, parse the expression
Expand All @@ -38,14 +39,11 @@ func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
projection.Binding = expr.Binding
// 3. compute the projection func
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
program, err := compiler.Compile(expr.Statement)
if err != nil {
return nil, field.Invalid(path, expr.Statement, err.Error())
}
projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) {
program, err := compile()
if err != nil {
return nil, false, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, false, err
Expand Down Expand Up @@ -83,28 +81,25 @@ func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
return nil, false, errors.New("projection not recognized")
}
}
return
return &projection, nil
}

func ParseScalar(in any, compiler compilers.Compilers) (ScalarHandler, error) {
func ParseScalar(path *field.Path, in any, compiler compilers.Compilers) (ScalarHandler, *field.Error) {
switch typed := in.(type) {
case string:
expr := expression.Parse(typed)
if expr.Foreach {
return nil, errors.New("foreach is not supported in scalar projections")
return nil, field.Invalid(path, typed, "foreach is not supported in scalar projections")
}
if expr.Binding != "" {
return nil, errors.New("binding is not supported in scalar projections")
return nil, field.Invalid(path, typed, "binding is not supported in scalar projections")
}
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
program, err := compiler.Compile(expr.Statement)
if err != nil {
return nil, field.Invalid(path, expr.Statement, err.Error())
}
return func(value any, bindings binding.Bindings) (any, error) {
program, err := compile()
if err != nil {
return nil, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, err
Expand Down
19 changes: 11 additions & 8 deletions pkg/core/projection/projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,18 @@ func TestParseMap(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
proj := ParseMapKey(tt.key, compiler)
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
proj, err := ParseMapKey(nil, tt.key, compiler)
assert.Nil(t, err)
{
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
}
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
})
}
}
Loading