Skip to content

Commit

Permalink
feat: implement compile time path propagation
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly committed Sep 24, 2024
1 parent 123686c commit e04bac5
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 73 deletions.
5 changes: 3 additions & 2 deletions pkg/apis/policy/v1alpha1/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/projection"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
)

// Any can be any type.
Expand All @@ -20,8 +21,8 @@ func NewAny(value any) Any {
}
}

func (t *Any) Compile(compilers compilers.Compilers) (projection.ScalarHandler, error) {
return projection.ParseScalar(t._value, compilers)
func (t *Any) Compile(path *field.Path, compilers compilers.Compilers) (projection.ScalarHandler, *field.Error) {
return projection.ParseScalar(path, t._value, compilers)
}

func (a *Any) MarshalJSON() ([]byte, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/policy/v1alpha1/assertion_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/kyverno/kyverno-json/pkg/core/assertion"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
)

// +k8s:deepcopy-gen=false
Expand All @@ -20,8 +21,8 @@ func NewAssertionTree(value any) AssertionTree {
}
}

func (t *AssertionTree) Compile(compilers compilers.Compilers) (assertion.Assertion, error) {
return assertion.Parse(t._tree, compilers)
func (t *AssertionTree) Compile(path *field.Path, compilers compilers.Compilers) (assertion.Assertion, *field.Error) {
return assertion.Parse(path, t._tree, compilers)
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
Expand Down
24 changes: 10 additions & 14 deletions pkg/core/assertion/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,7 @@ type Assertion interface {
Assert(*field.Path, any, binding.Bindings) (field.ErrorList, error)
}

func Parse(assertion any, compiler compilers.Compilers) (Assertion, error) {
out, err := parse(nil, assertion, compiler)
if err != nil {
return nil, err
}
return out, nil
}

func parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) {
func Parse(path *field.Path, assertion any, compiler compilers.Compilers) (Assertion, *field.Error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(path, assertion, compiler)
Expand Down Expand Up @@ -68,7 +60,7 @@ func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
path := path.Index(i)
sub, err := parse(path, valueOf.Index(i).Interface(), compiler)
sub, err := Parse(path, valueOf.Index(i).Interface(), compiler)
if err != nil {
return nil, err
}
Expand All @@ -80,7 +72,7 @@ func parseSlice(path *field.Path, assertion any, compiler compilers.Compilers) (
// mapNode is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
type mapNode map[any]struct {
projection.Projection
*projection.Projection
Assertion
}

Expand Down Expand Up @@ -154,13 +146,17 @@ func parseMap(path *field.Path, assertion any, compiler compilers.Compilers) (ma
key := iter.Key().Interface()
value := iter.Value().Interface()
path := path.Child(fmt.Sprint(key))
assertion, err := parse(path, value, compiler)
projection, err := projection.ParseMapKey(path, key, compiler)
if err != nil {
return nil, err
}
assertion, err := Parse(path, value, compiler)
if err != nil {
return nil, err
}
entry := assertions[key]
entry.Projection = projection
entry.Assertion = assertion
entry.Projection = projection.ParseMapKey(key, compiler)
assertions[key] = entry
}
return assertions, nil
Expand All @@ -184,7 +180,7 @@ func (node scalarNode) Assert(path *field.Path, value any, bindings binding.Bind
}

func parseScalar(path *field.Path, in any, compiler compilers.Compilers) (scalarNode, *field.Error) {
proj, err := projection.ParseScalar(in, compiler)
proj, err := projection.ParseScalar(path, in, compiler)
if err != nil {
return nil, field.InternalError(path, err)
}
Expand Down
18 changes: 10 additions & 8 deletions pkg/core/assertion/assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,17 @@ func TestAssert(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler)
parsed, err := Parse(nil, tt.assertion, compiler)
assert.NoError(t, err)
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
{
got, err := parsed.Assert(nil, tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
}
assert.Equal(t, tt.want, got)
})
}
}
Expand All @@ -80,7 +82,7 @@ func TestParse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
parsed, err := Parse(tt.assertion, compiler)
parsed, err := Parse(nil, tt.assertion, compiler)
if tt.wantErr {
assert.Error(t, err)
} else {
Expand Down
35 changes: 15 additions & 20 deletions pkg/core/projection/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package projection
import (
"errors"
"reflect"
"sync"

"github.com/jmespath-community/go-jmespath/pkg/binding"
"github.com/kyverno/kyverno-json/pkg/core/compilers"
"github.com/kyverno/kyverno-json/pkg/core/expression"
reflectutils "github.com/kyverno/kyverno-json/pkg/utils/reflect"
"k8s.io/apimachinery/pkg/util/validation/field"
)

type (
Expand All @@ -27,7 +27,8 @@ type Projection struct {
Handler MapKeyHandler
}

func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
func ParseMapKey(path *field.Path, in any, compiler compilers.Compilers) (*Projection, *field.Error) {
var projection Projection
switch typed := in.(type) {
case string:
// 1. if we have a string, parse the expression
Expand All @@ -38,14 +39,11 @@ func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
projection.Binding = expr.Binding
// 3. compute the projection func
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
program, err := compiler.Compile(expr.Statement)
if err != nil {
return nil, field.Invalid(path, expr.Statement, err.Error())
}
projection.Handler = func(value any, bindings binding.Bindings) (any, bool, error) {
program, err := compile()
if err != nil {
return nil, false, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, false, err
Expand Down Expand Up @@ -83,28 +81,25 @@ func ParseMapKey(in any, compiler compilers.Compilers) (projection Projection) {
return nil, false, errors.New("projection not recognized")
}
}
return
return &projection, nil
}

func ParseScalar(in any, compiler compilers.Compilers) (ScalarHandler, error) {
func ParseScalar(path *field.Path, in any, compiler compilers.Compilers) (ScalarHandler, *field.Error) {
switch typed := in.(type) {
case string:
expr := expression.Parse(typed)
if expr.Foreach {
return nil, errors.New("foreach is not supported in scalar projections")
return nil, field.Invalid(path, typed, "foreach is not supported in scalar projections")
}
if expr.Binding != "" {
return nil, errors.New("binding is not supported in scalar projections")
return nil, field.Invalid(path, typed, "binding is not supported in scalar projections")
}
if compiler := compiler.Compiler(expr.Compiler); compiler != nil {
compile := sync.OnceValues(func() (compilers.Program, error) {
return compiler.Compile(expr.Statement)
})
program, err := compiler.Compile(expr.Statement)
if err != nil {
return nil, field.Invalid(path, expr.Statement, err.Error())
}
return func(value any, bindings binding.Bindings) (any, error) {
program, err := compile()
if err != nil {
return nil, err
}
projected, err := program(value, bindings)
if err != nil {
return nil, err
Expand Down
19 changes: 11 additions & 8 deletions pkg/core/projection/projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,18 @@ func TestParseMap(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compiler := compilers.DefaultCompilers
proj := ParseMapKey(tt.key, compiler)
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
proj, err := ParseMapKey(nil, tt.key, compiler)
assert.Nil(t, err)
{
got, found, err := proj.Handler(tt.value, tt.bindings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
}
assert.Equal(t, tt.wantFound, found)
assert.Equal(t, tt.want, got)
})
}
}
38 changes: 19 additions & 19 deletions pkg/json-engine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ func (c *compiler) compileContextEntry(
path *field.Path,
compilers compilers.Compilers,
in v1alpha1.ContextEntry,
) (func(any, binding.Bindings) binding.Bindings, error) {
) (func(any, binding.Bindings) binding.Bindings, *field.Error) {
if in.Compiler != nil {
compilers = compilers.WithDefaultCompiler(string(*in.Compiler))
}
handler, err := in.Variable.Compile(compilers)
handler, err := in.Variable.Compile(path.Child("variable"), compilers)
if err != nil {
return nil, field.InternalError(path.Child("variable"), err)
return nil, err
}
return func(resource any, bindings binding.Bindings) binding.Bindings {
binding := binding.NewDelegate(
Expand All @@ -45,7 +45,7 @@ func (c *compiler) compileContextEntries(
path *field.Path,
compilers compilers.Compilers,
in ...v1alpha1.ContextEntry,
) (func(any, binding.Bindings) binding.Bindings, error) {
) (func(any, binding.Bindings) binding.Bindings, *field.Error) {
var out []func(any, binding.Bindings) binding.Bindings
for i, entry := range in {
entry, err := c.compileContextEntry(path.Index(i), compilers, entry)
Expand All @@ -66,7 +66,7 @@ func (c *compiler) compileMatch(
path *field.Path,
compilers compilers.Compilers,
in *v1alpha1.Match,
) (func(any, binding.Bindings) (field.ErrorList, error), error) {
) (func(any, binding.Bindings) (field.ErrorList, error), *field.Error) {
if in == nil {
return nil, nil
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func (c *compiler) compileAssert(
path *field.Path,
compilers compilers.Compilers,
in v1alpha1.Assert,
) (func(any, binding.Bindings) (Results, error), error) {
) (func(any, binding.Bindings) (Results, error), *field.Error) {
if in.Compiler != nil {
compilers = compilers.WithDefaultCompiler(string(*in.Compiler))
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func (c *compiler) compileAssertions(
path *field.Path,
compilers compilers.Compilers,
in ...v1alpha1.Assertion,
) ([]func(any, binding.Bindings) (Result, error), error) {
) ([]func(any, binding.Bindings) (Result, error), *field.Error) {
var out []func(any, binding.Bindings) (Result, error)
for i, in := range in {
if in, err := c.compileAssertion(path.Index(i), compilers, in); err != nil {
Expand All @@ -181,7 +181,7 @@ func (c *compiler) compileAssertion(
path *field.Path,
compilers compilers.Compilers,
in v1alpha1.Assertion,
) (func(any, binding.Bindings) (Result, error), error) {
) (func(any, binding.Bindings) (Result, error), *field.Error) {
if in.Compiler != nil {
compilers = compilers.WithDefaultCompiler(string(*in.Compiler))
}
Expand All @@ -206,7 +206,7 @@ func (c *compiler) compileAssertionTrees(
path *field.Path,
compilers compilers.Compilers,
in ...v1alpha1.AssertionTree,
) ([]func(any, binding.Bindings) (field.ErrorList, error), error) {
) ([]func(any, binding.Bindings) (field.ErrorList, error), *field.Error) {
var out []func(any, binding.Bindings) (field.ErrorList, error)
for i, in := range in {
if in, err := c.compileAssertionTree(path.Index(i), compilers, in); err != nil {
Expand All @@ -222,8 +222,8 @@ func (c *compiler) compileAssertionTree(
path *field.Path,
compilers compilers.Compilers,
in v1alpha1.AssertionTree,
) (func(any, binding.Bindings) (field.ErrorList, error), error) {
check, err := in.Compile(compilers)
) (func(any, binding.Bindings) (field.ErrorList, error), *field.Error) {
check, err := in.Compile(path, compilers)
if err != nil {
return nil, err
}
Expand All @@ -233,18 +233,18 @@ func (c *compiler) compileAssertionTree(
}

func (c *compiler) compileIdentifier(
_ *field.Path,
path *field.Path,
compilers compilers.Compilers,
in string,
) (func(any, binding.Bindings) string, error) {
) (func(any, binding.Bindings) string, *field.Error) {
if in == "" {
return func(resource any, bindings binding.Bindings) string {
return ""
}, nil
}
program, err := compilers.Jp.Compile(in)
if err != nil {
return nil, err
return nil, field.InternalError(path, err)
}
return func(resource any, bindings binding.Bindings) string {
result, err := program(resource, bindings)
Expand All @@ -260,7 +260,7 @@ func (c *compiler) compileFeedbacks(
path *field.Path,
compilers compilers.Compilers,
in ...v1alpha1.Feedback,
) (func(any, binding.Bindings) map[string]Feedback, error) {
) (func(any, binding.Bindings) map[string]Feedback, *field.Error) {
if len(in) == 0 {
return func(any, binding.Bindings) map[string]Feedback {
return nil
Expand All @@ -284,14 +284,14 @@ func (c *compiler) compileFeedbacks(
}

func (c *compiler) compileFeedback(
_ *field.Path,
path *field.Path,
compilers compilers.Compilers,
in v1alpha1.Feedback,
) (func(any, binding.Bindings) Feedback, error) {
) (func(any, binding.Bindings) Feedback, *field.Error) {
if in.Compiler != nil {
compilers = compilers.WithDefaultCompiler(string(*in.Compiler))
}
handler, err := in.Value.Compile(compilers)
handler, err := in.Value.Compile(path.Child("value"), compilers)
if err != nil {
return nil, err
}
Expand All @@ -310,7 +310,7 @@ func (c *compiler) compileRule(
path *field.Path,
compilers compilers.Compilers,
in v1alpha1.ValidatingRule,
) (func(any, binding.Bindings) *RuleResponse, error) {
) (func(any, binding.Bindings) *RuleResponse, *field.Error) {
if in.Compiler != nil {
compilers = compilers.WithDefaultCompiler(string(*in.Compiler))
}
Expand Down

0 comments on commit e04bac5

Please sign in to comment.