From 6bbc8d5b3a74905157bd9fe430711c5beb6f2e03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nelson=20Gon=C3=A7alves?= Date: Mon, 20 May 2024 22:50:57 +0100 Subject: [PATCH] Support cache skipping for `Load()` calls that throw `SkipCacheError` --- dataloader.go | 17 ++++++++++++- dataloader_test.go | 63 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/dataloader.go b/dataloader.go index 3c3a5ce..b162ba3 100644 --- a/dataloader.go +++ b/dataloader.go @@ -60,6 +60,20 @@ func (p *PanicErrorWrapper) Error() string { return p.panicError.Error() } +// SkipCacheError wraps the error interface. +// The cache should not store SkipCacheErrors. +type SkipCacheError struct { + err error +} + +func (s *SkipCacheError) Error() string { + return s.err.Error() +} + +func NewSkipCacheError(err error) *SkipCacheError { + return &SkipCacheError{err: err} +} + // Loader implements the dataloader.Interface. type Loader[K comparable, V any] struct { // the batch function to be used by this loader @@ -232,7 +246,8 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { result.mu.RLock() defer result.mu.RUnlock() var ev *PanicErrorWrapper - if result.value.Error != nil && errors.As(result.value.Error, &ev) { + var es *SkipCacheError + if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){ l.Clear(ctx, key) } return result.value.Data, result.value.Error diff --git a/dataloader_test.go b/dataloader_test.go index 7acab3c..02e5888 100644 --- a/dataloader_test.go +++ b/dataloader_test.go @@ -79,6 +79,45 @@ func TestLoader(t *testing.T) { } }) + t.Run("test Load Method not caching results with errors of type SkipCacheError", func(t *testing.T) { + t.Parallel() + skipCacheLoader, loadCalls := SkipCacheErrorLoader(3, "1") + ctx := context.Background() + futures1 := skipCacheLoader.LoadMany(ctx, []string{"1", "2", "3"}) + _, errs1 := futures1() + var errCount int = 0 + var nilCount int = 0 + for _, err := range errs1 { + if err == nil { + nilCount++ + } else { + errCount++ + } + } + if errCount != 1 { + t.Error("Expected an error on only key \"1\"") + } + + if nilCount != 2 { + t.Error("Expected the other errors to be nil") + } + + futures2 := skipCacheLoader.LoadMany(ctx, []string{"2", "3", "1"}) + _, errs2 := futures2() + // There should be no errors in the second batch, as the only key that was not cached + // this time around will not throw an error + if errs2 != nil { + t.Error("Expected LoadMany() to return nil error slice when no errors occurred") + } + + calls := (*loadCalls)[1] + expected := []string{"1"} + + if !reflect.DeepEqual(calls, expected) { + t.Errorf("Expected load calls %#v, got %#v", expected, calls) + } + }) + t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) { t.Parallel() defer func() { @@ -622,6 +661,30 @@ func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) { return errorCacheLoader, &loadCalls } +func SkipCacheErrorLoader[K comparable](max int, onceErrorKey K) (*Loader[K, K], *[][]K) { + var mu sync.Mutex + var loadCalls [][]K + errorThrown := false + skipCacheErrorLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { + var results []*Result[K] + mu.Lock() + loadCalls = append(loadCalls, keys) + mu.Unlock() + // return a non cacheable error for the first occurence of onceErrorKey + for _, k := range keys { + if !errorThrown && k == onceErrorKey { + results = append(results, &Result[K]{k, NewSkipCacheError(fmt.Errorf("non cacheable error"))}) + errorThrown = true + } else { + results = append(results, &Result[K]{k, nil}) + } + } + + return results + }, WithBatchCapacity[K, K](max)) + return skipCacheErrorLoader, &loadCalls +} + func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) { var mu sync.Mutex var loadCalls [][]K