From b5fd43b8dc5e010e65bdc7198a14293eb85187ce Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Mon, 24 Apr 2023 16:47:37 -0700 Subject: [PATCH] Fix bug of zero credit in rate limiter --- utils/rate_limiter.go | 8 ++++++-- utils/rate_limiter_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/utils/rate_limiter.go b/utils/rate_limiter.go index bf2f1316..cbb83ab3 100644 --- a/utils/rate_limiter.go +++ b/utils/rate_limiter.go @@ -21,7 +21,7 @@ import ( // RateLimiter is a filter used to check if a message that is worth itemCost units is within the rate limits. // -// TODO (breaking change) remove this interface in favor of public struct below +// # TODO (breaking change) remove this interface in favor of public struct below // // Deprecated, use ReconfigurableRateLimiter. type RateLimiter interface { @@ -55,9 +55,13 @@ type ReconfigurableRateLimiter struct { // NewRateLimiter creates a new ReconfigurableRateLimiter. func NewRateLimiter(creditsPerSecond, maxBalance float64) *ReconfigurableRateLimiter { + balance := maxBalance + if creditsPerSecond == 0 { + balance = 0 + } return &ReconfigurableRateLimiter{ creditsPerSecond: creditsPerSecond, - balance: maxBalance, + balance: balance, maxBalance: maxBalance, lastTick: time.Now(), timeNow: time.Now, diff --git a/utils/rate_limiter_test.go b/utils/rate_limiter_test.go index 63b61713..bbd13a92 100644 --- a/utils/rate_limiter_test.go +++ b/utils/rate_limiter_test.go @@ -73,6 +73,22 @@ func TestRateLimiterMaxBalance(t *testing.T) { assert.False(t, rl.CheckCredit(1.0)) } +func TestRateLimiterZeroCreditsPerSecond(t *testing.T) { + rl := NewRateLimiter(0, 1.0) + // stop time + ts := time.Now() + rl.lastTick = ts + rl.timeNow = func() time.Time { + return ts + } + assert.False(t, rl.CheckCredit(1.0), "on initialization, should not have any credit for 1 message") + + rl.timeNow = func() time.Time { + return ts.Add(time.Second * 20) + } + assert.False(t, rl.CheckCredit(1.0)) +} + func TestRateLimiterReconfigure(t *testing.T) { rl := NewRateLimiter(1, 1.0) assertBalance := func(expected float64) {