From afb7ee69d121aa05d5699fc2d99bd89402181c9e Mon Sep 17 00:00:00 2001 From: Sander van Harmelen Date: Fri, 13 Dec 2024 18:06:08 +0100 Subject: [PATCH] Fix a logical error in the Loader.Load method --- dataloader.go | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/dataloader.go b/dataloader.go index 3c3a5ce..b651cc4 100644 --- a/dataloader.go +++ b/dataloader.go @@ -209,14 +209,31 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { value *Result[V] } - // lock to prevent duplicate keys coming in before item has been added to cache. + // We need to lock both the batchLock and cacheLock because the batcher can + // reset the cache when either the batchCap or the wait time is reached. + // + // When we would only lock the cacheLock while doing l.cache.Get and/or + // l.cache.Set, it could be that the batcher resets the cache after those + // operations have finished but before the new request (if any) is send to the + // batcher. + // + // In that case it is no longer guaranteed that the keys passed to the BatchFunc + // function are unique as the cache has been reset so if the same key is + // requested again before the new batcher is started, the same key will be + // send to the batcher again causing unexpected behavior in the BatchFunc. + l.batchLock.Lock() l.cacheLock.Lock() + if v, ok := l.cache.Get(ctx, key); ok { + l.cacheLock.Unlock() + l.batchLock.Unlock() defer finish(v) - defer l.cacheLock.Unlock() return v } + defer l.batchLock.Unlock() + defer l.cacheLock.Unlock() + thunk := func() (V, error) { result.mu.RLock() resultNotSet := result.value == nil @@ -240,13 +257,11 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { defer finish(thunk) l.cache.Set(ctx, key, thunk) - l.cacheLock.Unlock() // this is sent to batch fn. It contains the key and the channel to return // the result on req := &batchRequest[K, V]{key, c} - l.batchLock.Lock() // start the batch window if it hasn't already started. if l.curBatcher == nil { l.curBatcher = l.newBatcher(l.silent, l.tracer) @@ -274,7 +289,6 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { l.reset() } } - l.batchLock.Unlock() return thunk }