Skip to content

Commit

Permalink
feat: hold compiled assertion in api
Browse files Browse the repository at this point in the history
Signed-off-by: Charles-Edouard Brétéché <[email protected]>
  • Loading branch information
eddycharly committed Sep 18, 2024
1 parent f349c31 commit 9f3325d
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 110 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ codegen-deepcopy: $(PACKAGE_SHIM) $(DEEPCOPY_GEN) ## Generate deep copy function
codegen-crds: $(CONTROLLER_GEN) ## Generate CRDs
@echo Generate crds... >&2
@rm -rf $(CRDS_PATH)
@$(CONTROLLER_GEN) paths=./pkg/apis/... crd:crdVersions=v1 output:dir=$(CRDS_PATH)
@$(CONTROLLER_GEN) paths=./pkg/apis/... crd:crdVersions=v1,ignoreUnexportedFields=true output:dir=$(CRDS_PATH)
@echo Copy generated CRDs to embed in the CLI... >&2
@rm -rf pkg/data/crds && mkdir -p pkg/data/crds
@cp $(CRDS_PATH)/* pkg/data/crds
Expand Down
48 changes: 5 additions & 43 deletions pkg/apis/policy/v1alpha1/any.go
Original file line number Diff line number Diff line change
@@ -1,65 +1,27 @@
package v1alpha1

import (
"fmt"

"k8s.io/apimachinery/pkg/util/json"
)

func deepCopy(in any) any {
if in == nil {
return nil
}
switch in := in.(type) {
case string:
return in
case int:
return in
case int32:
return in
case int64:
return in
case float32:
return in
case float64:
return in
case bool:
return in
case []any:
var out []any
for _, in := range in {
out = append(out, deepCopy(in))
}
return out
case map[string]any:
out := map[string]any{}
for k, in := range in {
out[k] = deepCopy(in)
}
return out
}
panic(fmt.Sprintf("deep copy failed - unrecognized type %T", in))
}

// Any can be any type.
// +k8s:deepcopy-gen=false
// +kubebuilder:validation:XPreserveUnknownFields
// +kubebuilder:validation:Type:=""
type Any struct {
// +optional
value any `json:"-"`
_value any
}

func NewAny(value any) Any {
return Any{value}
}

func (t *Any) Value() any {
return t.value
return t._value
}

func (in *Any) DeepCopyInto(out *Any) {
out.value = deepCopy(in.value)
out._value = deepCopy(in._value)
}

func (in *Any) DeepCopy() *Any {
Expand All @@ -72,7 +34,7 @@ func (in *Any) DeepCopy() *Any {
}

func (a *Any) MarshalJSON() ([]byte, error) {
return json.Marshal(a.value)
return json.Marshal(a._value)
}

func (a *Any) UnmarshalJSON(data []byte) error {
Expand All @@ -81,6 +43,6 @@ func (a *Any) UnmarshalJSON(data []byte) error {
if err != nil {
return err
}
a.value = v
a._value = v
return nil
}
39 changes: 0 additions & 39 deletions pkg/apis/policy/v1alpha1/assertion.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,5 @@
package v1alpha1

import (
"k8s.io/apimachinery/pkg/util/json"
)

// +k8s:deepcopy-gen=false
// +kubebuilder:validation:XPreserveUnknownFields
// +kubebuilder:validation:Type:=""
// AssertionTree represents an assertion tree.
type AssertionTree struct {
// +optional
tree any `json:"-"`
}

func NewAssertionTree(value any) AssertionTree {
return AssertionTree{value}
}

func (t *AssertionTree) Raw() any {
return t.tree
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
return json.Marshal(a.tree)
}

func (a *AssertionTree) UnmarshalJSON(data []byte) error {
var v any
err := json.Unmarshal(data, &v)
if err != nil {
return err
}
a.tree = v
return nil
}

func (in *AssertionTree) DeepCopyInto(out *AssertionTree) {
out.tree = deepCopy(in.tree)
}

// Assertion contains an assertion tree associated with a message.
type Assertion struct {
// Message is the message associated message.
Expand Down
56 changes: 56 additions & 0 deletions pkg/apis/policy/v1alpha1/assertion_tree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package v1alpha1

import (
"context"
"sync"

"github.com/kyverno/kyverno-json/pkg/engine/assert"
"k8s.io/apimachinery/pkg/util/json"
)

// +k8s:deepcopy-gen=false
// +kubebuilder:validation:XPreserveUnknownFields
// +kubebuilder:validation:Type:=""
// AssertionTree represents an assertion tree.
type AssertionTree struct {
_tree any
_assertion func() (assert.Assertion, error)
}

func NewAssertionTree(value any) AssertionTree {
return AssertionTree{
_tree: value,
_assertion: sync.OnceValues(func() (assert.Assertion, error) {
return assert.Parse(context.Background(), value)
}),

Check warning on line 25 in pkg/apis/policy/v1alpha1/assertion_tree.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/assertion_tree.go#L20-L25

Added lines #L20 - L25 were not covered by tests
}
}

func (t *AssertionTree) Assertion() (assert.Assertion, error) {
if t._tree == nil {
return nil, nil
}
return t._assertion()

Check warning on line 33 in pkg/apis/policy/v1alpha1/assertion_tree.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/assertion_tree.go#L29-L33

Added lines #L29 - L33 were not covered by tests
}

func (a *AssertionTree) MarshalJSON() ([]byte, error) {
return json.Marshal(a._tree)

Check warning on line 37 in pkg/apis/policy/v1alpha1/assertion_tree.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/assertion_tree.go#L36-L37

Added lines #L36 - L37 were not covered by tests
}

func (a *AssertionTree) UnmarshalJSON(data []byte) error {
var v any
err := json.Unmarshal(data, &v)
if err != nil {
return err
}
a._tree = v
a._assertion = sync.OnceValues(func() (assert.Assertion, error) {
return assert.Parse(context.Background(), v)
})
return nil

Check warning on line 50 in pkg/apis/policy/v1alpha1/assertion_tree.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/assertion_tree.go#L40-L50

Added lines #L40 - L50 were not covered by tests
}

func (in *AssertionTree) DeepCopyInto(out *AssertionTree) {
out._tree = deepCopy(in._tree)
out._assertion = in._assertion

Check warning on line 55 in pkg/apis/policy/v1alpha1/assertion_tree.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/assertion_tree.go#L53-L55

Added lines #L53 - L55 were not covered by tests
}
40 changes: 40 additions & 0 deletions pkg/apis/policy/v1alpha1/deep_copy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package v1alpha1

import (
"fmt"
)

func deepCopy(in any) any {
if in == nil {
return nil
}
switch in := in.(type) {
case string:
return in
case int:
return in
case int32:
return in
case int64:
return in
case float32:
return in
case float64:
return in
case bool:
return in

Check warning on line 25 in pkg/apis/policy/v1alpha1/deep_copy.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/deep_copy.go#L16-L25

Added lines #L16 - L25 were not covered by tests
case []any:
var out []any
for _, in := range in {
out = append(out, deepCopy(in))
}
return out
case map[string]any:
out := map[string]any{}
for k, in := range in {
out[k] = deepCopy(in)
}
return out
}
panic(fmt.Sprintf("deep copy failed - unrecognized type %T", in))

Check warning on line 39 in pkg/apis/policy/v1alpha1/deep_copy.go

View check run for this annotation

Codecov / codecov/patch

pkg/apis/policy/v1alpha1/deep_copy.go#L39

Added line #L39 was not covered by tests
}
2 changes: 1 addition & 1 deletion pkg/engine/assert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestAssert(t *testing.T) {
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parsed, err := Parse(context.TODO(), nil, tt.assertion)
parsed, err := Parse(context.TODO(), tt.assertion)
tassert.NoError(t, err)
got, err := Assert(context.TODO(), nil, parsed, tt.value, tt.bindings)
if tt.wantErr {
Expand Down
25 changes: 13 additions & 12 deletions pkg/engine/assert/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package assert

import (
"context"
"errors"
"fmt"
"reflect"

Expand All @@ -13,14 +14,14 @@ import (
"k8s.io/apimachinery/pkg/util/validation/field"
)

func Parse(ctx context.Context, path *field.Path, assertion any) (Assertion, error) {
func Parse(ctx context.Context, assertion any) (Assertion, error) {
switch reflectutils.GetKind(assertion) {
case reflect.Slice:
return parseSlice(ctx, path, assertion)
return parseSlice(ctx, assertion)

Check warning on line 20 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L20

Added line #L20 was not covered by tests
case reflect.Map:
return parseMap(ctx, path, assertion)
return parseMap(ctx, assertion)
default:
return parseScalar(ctx, path, assertion)
return parseScalar(ctx, assertion)

Check warning on line 24 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L24

Added line #L24 was not covered by tests
}
}

Expand All @@ -34,11 +35,11 @@ func (n node) assert(ctx context.Context, path *field.Path, value any, bindings
// parseSlice is the assertion represented by a slice.
// it first compares the length of the analysed resource with the length of the descendants.
// if lengths match all descendants are evaluated with their corresponding items.
func parseSlice(ctx context.Context, path *field.Path, assertion any) (node, error) {
func parseSlice(ctx context.Context, assertion any) (node, error) {

Check warning on line 38 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L38

Added line #L38 was not covered by tests
var assertions []Assertion
valueOf := reflect.ValueOf(assertion)
for i := 0; i < valueOf.Len(); i++ {
sub, err := Parse(ctx, path.Index(i), valueOf.Index(i).Interface())
sub, err := Parse(ctx, valueOf.Index(i).Interface())

Check warning on line 42 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L42

Added line #L42 was not covered by tests
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -70,7 +71,7 @@ func parseSlice(ctx context.Context, path *field.Path, assertion any) (node, err

// parseMap is the assertion represented by a map.
// it is responsible for projecting the analysed resource and passing the result to the descendant
func parseMap(ctx context.Context, path *field.Path, assertion any) (node, error) {
func parseMap(ctx context.Context, assertion any) (node, error) {
assertions := map[any]struct {
*expression
Assertion
Expand All @@ -79,7 +80,7 @@ func parseMap(ctx context.Context, path *field.Path, assertion any) (node, error
for iter.Next() {
key := iter.Key().Interface()
value := iter.Value().Interface()
assertion, err := Parse(ctx, path.Child(fmt.Sprint(key)), value)
assertion, err := Parse(ctx, value)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -158,21 +159,21 @@ func parseMap(ctx context.Context, path *field.Path, assertion any) (node, error
// parseScalar is the assertion represented by a leaf.
// it receives a value and compares it with an expected value.
// the expected value can be the result of an expression.
func parseScalar(ctx context.Context, path *field.Path, assertion any) (node, error) {
func parseScalar(ctx context.Context, assertion any) (node, error) {

Check warning on line 162 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L162

Added line #L162 was not covered by tests
expression := parseExpression(ctx, assertion)
// we only project if the expression uses the engine syntax
// this is to avoid the case where the value is a map and the RHS is a string
var project func(ctx context.Context, value any, bindings binding.Bindings, opts ...template.Option) (any, error)
if expression != nil && expression.engine != "" {
if expression.foreachName != "" {
return nil, field.Invalid(path, assertion, "foreach is not supported on the RHS")
return nil, errors.New("foreach is not supported on the RHS")

Check warning on line 169 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L169

Added line #L169 was not covered by tests
}
if expression.binding != "" {
return nil, field.Invalid(path, assertion, "binding is not supported on the RHS")
return nil, errors.New("binding is not supported on the RHS")

Check warning on line 172 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L172

Added line #L172 was not covered by tests
}
ast, err := expression.ast()
if err != nil {
return nil, field.InternalError(path, err)
return nil, err

Check warning on line 176 in pkg/engine/assert/parse.go

View check run for this annotation

Codecov / codecov/patch

pkg/engine/assert/parse.go#L176

Added line #L176 was not covered by tests
}
project = func(ctx context.Context, value any, bindings jpbinding.Bindings, opts ...template.Option) (any, error) {
return template.ExecuteAST(ctx, ast, value, bindings, opts...)
Expand Down
20 changes: 12 additions & 8 deletions pkg/matching/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert,
var fails []Result
path := path.Child("any")
for i, assertion := range match.Any {
parsed, err := assert.Parse(ctx, path.Index(i).Child("check"), assertion.Check.Raw())
path := path.Index(i).Child("check")
parsed, err := assertion.Check.Assertion()
if err != nil {
return fails, err
}
checkFails, err := assert.Assert(ctx, path.Index(i).Child("check"), parsed, actual, bindings, opts...)
checkFails, err := assert.Assert(ctx, path, parsed, actual, bindings, opts...)
if err != nil {
return fails, err
}
Expand All @@ -76,11 +77,12 @@ func MatchAssert(ctx context.Context, path *field.Path, match *v1alpha1.Assert,
var fails []Result
path := path.Child("all")
for i, assertion := range match.All {
parsed, err := assert.Parse(ctx, path.Index(i).Child("check"), assertion.Check.Raw())
path := path.Index(i).Child("check")
parsed, err := assertion.Check.Assertion()
if err != nil {
return fails, err
}
checkFails, err := assert.Assert(ctx, path.Index(i).Child("check"), parsed, actual, bindings, opts...)
checkFails, err := assert.Assert(ctx, path, parsed, actual, bindings, opts...)
if err != nil {
return fails, err
}
Expand Down Expand Up @@ -126,11 +128,12 @@ func Match(ctx context.Context, path *field.Path, match *v1alpha1.Match, actual
func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.AssertionTree, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
var errs field.ErrorList
for i, assertion := range assertions {
parsed, err := assert.Parse(ctx, path.Index(i), assertion.Raw())
path := path.Index(i)
assertion, err := assertion.Assertion()
if err != nil {
return errs, err
}
_errs, err := assert.Assert(ctx, path.Index(i), parsed, actual, bindings, opts...)
_errs, err := assert.Assert(ctx, path, assertion, actual, bindings, opts...)
if err != nil {
return errs, err
}
Expand All @@ -145,11 +148,12 @@ func MatchAny(ctx context.Context, path *field.Path, assertions []v1alpha1.Asser
func MatchAll(ctx context.Context, path *field.Path, assertions []v1alpha1.AssertionTree, actual any, bindings binding.Bindings, opts ...template.Option) (field.ErrorList, error) {
var errs field.ErrorList
for i, assertion := range assertions {
parsed, err := assert.Parse(ctx, path.Index(i), assertion.Raw())
path := path.Index(i)
assertion, err := assertion.Assertion()
if err != nil {
return errs, err
}
_errs, err := assert.Assert(ctx, path.Index(i), parsed, actual, bindings, opts...)
_errs, err := assert.Assert(ctx, path, assertion, actual, bindings, opts...)
if err != nil {
return errs, err
}
Expand Down
Loading

0 comments on commit 9f3325d

Please sign in to comment.