From ac834ce67862b7c4ed8bdbe86e664572f820c245 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Tue, 9 Oct 2018 19:05:06 -0400 Subject: [PATCH] Gracefully handle TERM signals (#206) Add a -term_timeout flag; when receiving the TERM signal, the proxy will wait up to `term_timeout` for existing connections to close. --- cmd/cloud_sql_proxy/cloud_sql_proxy.go | 24 +++++++++++++++++-- proxy/proxy/client.go | 33 ++++++++++++++++++-------- proxy/proxy/client_test.go | 32 +++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/cmd/cloud_sql_proxy/cloud_sql_proxy.go b/cmd/cloud_sql_proxy/cloud_sql_proxy.go index e834bb3d7..53ad53386 100644 --- a/cmd/cloud_sql_proxy/cloud_sql_proxy.go +++ b/cmd/cloud_sql_proxy/cloud_sql_proxy.go @@ -27,9 +27,11 @@ import ( "log" "net/http" "os" + "os/signal" "path/filepath" "strings" "sync" + "syscall" "time" "github.com/GoogleCloudPlatform/cloudsql-proxy/logging" @@ -77,6 +79,7 @@ can be removed automatically by this program.`) // Settings for limits maxConnections = flag.Uint64("max_connections", 0, `If provided, the maximum number of connections to establish before refusing new connections. Defaults to 0 (no limit)`) fdRlimit = flag.Uint64("fd_rlimit", limits.ExpectedFDs, `Sets the rlimit on the number of open file descriptors for the proxy to the provided value. If set to zero, disables attempts to set the rlimit. Defaults to a value which can support 4K connections to one instance`) + termTimeout = flag.Duration("term_timeout", 0, "When set, the proxy will wait for existing connections to close before terminating. Any connections that haven't closed after the timeout will be dropped") // Settings for authentication. 
token = flag.String("token", "", "When set, the proxy uses this Bearer token for authorization.") @@ -496,7 +499,7 @@ func main() { } logging.Infof("Ready for new connections") - (&proxy.Client{ + proxyClient := &proxy.Client{ Port: port, MaxConnections: *maxConnections, Certs: certs.NewCertSourceOpts(client, certs.RemoteOpts{ @@ -507,5 +510,22 @@ func main() { }), Conns: connset, RefreshCfgThrottle: refreshCfgThrottle, - }).Run(connSrc) + } + + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + + go func() { + <-signals + logging.Infof("Received TERM signal. Waiting up to %s before terminating.", *termTimeout) + + err := proxyClient.Shutdown(*termTimeout) + + if err == nil { + os.Exit(0) + } + os.Exit(2) + }() + + proxyClient.Run(connSrc) } diff --git a/proxy/proxy/client.go b/proxy/proxy/client.go index a57a37a62..35c4dca42 100644 --- a/proxy/proxy/client.go +++ b/proxy/proxy/client.go @@ -111,18 +111,15 @@ func (c *Client) Run(connSrc <-chan Conn) { } func (c *Client) handleConn(conn Conn) { - // Track connections count only if a maximum connections limit is set to avoid useless overhead - if c.MaxConnections > 0 { - active := atomic.AddUint64(&c.ConnectionsCounter, 1) + active := atomic.AddUint64(&c.ConnectionsCounter, 1) - // Deferred decrement of ConnectionsCounter upon connection closing - defer atomic.AddUint64(&c.ConnectionsCounter, ^uint64(0)) + // Deferred decrement of ConnectionsCounter upon connection closing + defer atomic.AddUint64(&c.ConnectionsCounter, ^uint64(0)) - if active > c.MaxConnections { - logging.Errorf("too many open connections (max %d)", c.MaxConnections) - conn.Conn.Close() - return - } + if c.MaxConnections > 0 && active > c.MaxConnections { + logging.Errorf("too many open connections (max %d)", c.MaxConnections) + conn.Conn.Close() + return } server, err := c.Dial(conn.Instance) @@ -323,3 +320,19 @@ func NewConnSrc(instance string, l net.Listener) <-chan Conn { }() return ch } + +// 
Shutdown waits up to a given amount of time for all active connections to +// close. Returns an error if there are still active connections after waiting +// for the whole length of the timeout. +func (c *Client) Shutdown(termTimeout time.Duration) error { + termTime := time.Now().Add(termTimeout) + for termTime.After(time.Now()) && atomic.LoadUint64(&c.ConnectionsCounter) > 0 { + time.Sleep(1) + } + + active := atomic.LoadUint64(&c.ConnectionsCounter) + if active == 0 { + return nil + } + return fmt.Errorf("%d active connections still exist after waiting for %v", active, termTimeout) +} diff --git a/proxy/proxy/client_test.go b/proxy/proxy/client_test.go index fc0b35f87..a59ad1e1c 100644 --- a/proxy/proxy/client_test.go +++ b/proxy/proxy/client_test.go @@ -183,3 +183,35 @@ func TestMaximumConnectionsCount(t *testing.T) { t.Errorf("client should have dialed exactly the maximum of %d connections (%d connections, %d dials)", maxConnections, numConnections, dials) } } + +func TestShutdownTerminatesEarly(t *testing.T) { + b := &fakeCerts{} + c := &Client{ + Certs: &blockingCertSource{ + map[string]*fakeCerts{ + instance: b, + }}, + Dialer: func(string, string) (net.Conn, error) { + return nil, nil + }, + } + + shutdown := make(chan bool, 1) + go func() { + c.Shutdown(1) + shutdown <- true + }() + + shutdownFinished := false + + // In case the code is actually broken and the client doesn't shut down quickly, don't cause the test to hang until it times out. + select { + case <-time.After(100 * time.Millisecond): + case shutdownFinished = <-shutdown: + } + + if !shutdownFinished { + t.Errorf("shutdown should have completed quickly because there are no active connections") + } + +}