From 3f5044452c7afcad84e40a360136b85c3ce73a01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Espino=20Garc=C3=ADa?= Date: Tue, 17 Dec 2024 14:14:06 +0100 Subject: [PATCH] Add android retry --- server/android_notification_server.go | 75 +++++++++++++++++++++++---- server/android_notification_test.go | 6 +-- server/server.go | 2 +- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/server/android_notification_server.go b/server/android_notification_server.go index 66cbf5c..3b5cbca 100644 --- a/server/android_notification_server.go +++ b/server/android_notification_server.go @@ -41,6 +41,7 @@ type AndroidNotificationServer struct { AndroidPushSettings AndroidPushSettings client *messaging.Client sendTimeout time.Duration + retryTimeout time.Duration } // serviceAccount contains a subset of the fields in service-account.json. @@ -54,12 +55,13 @@ type serviceAccount struct { TokenURI string `json:"token_uri"` } -func NewAndroidNotificationServer(settings AndroidPushSettings, logger *Logger, metrics *metrics, sendTimeoutSecs int) *AndroidNotificationServer { +func NewAndroidNotificationServer(settings AndroidPushSettings, logger *Logger, metrics *metrics, sendTimeoutSecs int, retryTimeoutSecs int) *AndroidNotificationServer { return &AndroidNotificationServer{ AndroidPushSettings: settings, metrics: metrics, logger: logger, sendTimeout: time.Duration(sendTimeoutSecs) * time.Second, + retryTimeout: time.Duration(retryTimeoutSecs) * time.Second, } } @@ -166,16 +168,8 @@ func (me *AndroidNotificationServer) SendNotification(msg *PushNotification) Pus }, } - ctx, cancel := context.WithTimeout(context.Background(), me.sendTimeout) - defer cancel() - me.logger.Infof("Sending android push notification for device=%v type=%v ackId=%v", me.AndroidPushSettings.Type, msg.Type, msg.AckID) - - start := time.Now() - _, err := me.client.Send(ctx, fcmMsg) - if me.metrics != nil { - me.metrics.observerNotificationResponse(PushNotifyAndroid, time.Since(start).Seconds()) - } + err := me.SendNotificationWithRetry(fcmMsg) if err != nil { errorCode, hasStatusCode := getErrorCode(err) @@ -233,6 +227,67 @@ func (me *AndroidNotificationServer) SendNotification(msg *PushNotification) Pus return NewOkPushResponse() } +func (me *AndroidNotificationServer) SendNotificationWithRetry(fcmMsg *messaging.Message) error { + var err error + waitTime := time.Second + + // Keep a general context to make sure the whole retry + // doesn't take longer than the timeout. + generalContext, cancelGeneralContext := context.WithTimeout(context.Background(), me.sendTimeout) + defer cancelGeneralContext() + + for retries := 0; retries < MAX_RETRIES; retries++ { + start := time.Now() + + retryContext, cancelRetryContext := context.WithTimeout(generalContext, me.retryTimeout) + defer cancelRetryContext() + _, err := me.client.Send(retryContext, fcmMsg) + if me.metrics != nil { + me.metrics.observerNotificationResponse(PushNotifyApple, time.Since(start).Seconds()) + } + + if err == nil { + break + } + + if !isRetryable(err) { + break + } + + me.logger.Errorf("Failed to send android push did=%v retry=%v error=%v", fcmMsg.Token, retries, err) + + if retries == MAX_RETRIES-1 { + me.logger.Errorf("Max retries reached did=%v", fcmMsg.Token) + break + } + + select { + case <-generalContext.Done(): + case <-time.After(waitTime): + } + + if generalContext.Err() != nil { + me.logger.Infof("Not retrying because context error did=%v retry=%v error=%v", fcmMsg.Token, retries, generalContext.Err()) + err = generalContext.Err() + break + } + + waitTime *= 2 + } + + return err +} + +func isRetryable(err error) bool { + // We retry the errors based on https://firebase.google.com/docs/cloud-messaging/http-server-ref + return messaging.IsInternal(err) || + messaging.IsQuotaExceeded(err) + + // messaging.IsUnavailable is retried by the default retry config in + // firebase.google.com/go/v4@v4.14.0/internal/http_client.go + // messaging.IsUnavailable(err) +} + func getErrorCode(err error) (string, bool) { if err == nil { return "", false diff --git a/server/android_notification_test.go b/server/android_notification_test.go index 510035e..83f7a67 100644 --- a/server/android_notification_test.go +++ b/server/android_notification_test.go @@ -23,7 +23,7 @@ func TestAndroidInitialize(t *testing.T) { // Verify error for no service file pushSettings := AndroidPushSettings{} cfg.AndroidPushSettings[0] = pushSettings - require.Error(t, NewAndroidNotificationServer(cfg.AndroidPushSettings[0], logger, nil, cfg.SendTimeoutSec).Initialize()) + require.Error(t, NewAndroidNotificationServer(cfg.AndroidPushSettings[0], logger, nil, cfg.SendTimeoutSec, cfg.RetryTimeoutSec).Initialize()) f, err := os.CreateTemp("", "example") require.NoError(t, err) @@ -34,7 +34,7 @@ func TestAndroidInitialize(t *testing.T) { // Verify error for bad JSON _, err = f.Write([]byte("badJSON")) require.NoError(t, err) - require.Error(t, NewAndroidNotificationServer(cfg.AndroidPushSettings[0], logger, nil, cfg.SendTimeoutSec).Initialize()) + require.Error(t, NewAndroidNotificationServer(cfg.AndroidPushSettings[0], logger, nil, cfg.SendTimeoutSec, cfg.RetryTimeoutSec).Initialize()) require.NoError(t, f.Truncate(0)) _, err = f.Seek(0, 0) @@ -46,7 +46,7 @@ func TestAndroidInitialize(t *testing.T) { ProjectID: "sample", })) require.NoError(t, f.Sync()) - require.NoError(t, NewAndroidNotificationServer(cfg.AndroidPushSettings[0], logger, nil, cfg.SendTimeoutSec).Initialize()) + require.NoError(t, NewAndroidNotificationServer(cfg.AndroidPushSettings[0], logger, nil, cfg.SendTimeoutSec, cfg.RetryTimeoutSec).Initialize()) require.NoError(t, f.Close()) } diff --git a/server/server.go b/server/server.go index 986be48..8fde611 100644 --- a/server/server.go +++ b/server/server.go @@ -80,7 +80,7 @@ func (s *Server) Start() { } for _, settings := range s.cfg.AndroidPushSettings { - server := NewAndroidNotificationServer(settings, s.logger, m, s.cfg.SendTimeoutSec) + server := NewAndroidNotificationServer(settings, s.logger, m, s.cfg.SendTimeoutSec, s.cfg.RetryTimeoutSec) err := server.Initialize() if err != nil { s.logger.Errorf("Failed to initialize client: %v", err)