diff --git a/backend/postgres/README.md b/backend/postgres/README.md index be3628506..44a4770ea 100644 --- a/backend/postgres/README.md +++ b/backend/postgres/README.md @@ -15,32 +15,12 @@ Configuration for the PostgreSQL connection pool of the microservice. * postgres user to access the database * `POSTGRES_DB` default: `postgres` * database to access -* `POSTGRES_MAX_RETRIES` default: `5` - * Maximum number of retries before giving up -* `POSTGRES_RETRY_STATEMENT_TIMEOUT` default: `false` - * Whether to retry queries cancelled because of statement_timeout -* `POSTGRES_MIN_RETRY_BACKOFF` default: `250ms` - * Minimum backoff between each retry -* `POSTGRES_MAX_RETRY_BACKOFF` default: `4s` - * Maximum backoff between each retry * `POSTGRES_DIAL_TIMEOUT` default: `5s` * Dial timeout for establishing new connections * `POSTGRES_READ_TIMEOUT` default: `30s` * Timeout for socket reads. If reached, commands will fail with a timeout instead of blocking * `POSTGRES_WRITE_TIMEOUT` default: `30s` * Timeout for socket writes. If reached, commands will fail with a timeout instead of blocking. -* `POSTGRES_POOL_SIZE` default: `100` - * Maximum number of socket connections -* `POSTGRES_MIN_IDLE_CONNECTIONS` default: `10` - * Minimum number of idle connections which is useful when establishing new connection is slow -* `POSTGRES_MAX_CONN_AGE` default: `30m` - * Connection age at which client retires (closes) the connection -* `POSTGRES_POOL_TIMEOUT` default: `31s` - * Time for which client waits for free connection if all connections are busy before returning an error -* `POSTGRES_IDLE_TIMEOUT` default: `5m` - * Amount of time after which client closes idle connections -* `POSTGRES_IDLE_CHECK_FREQUENCY` default: `1m` - * Frequency of idle checks made by idle connections reaper * `POSTGRES_HEALTH_CHECK_TABLE_NAME` default: `healthcheck` * Name of the Table that is created to try if database is writeable * `POSTGRES_HEALTH_CHECK_RESULT_TTL` default: `10s` @@ -53,7 +33,6 @@ Prometheus metrics exposed. * `pace_postgres_query_total{database}` Collects stats about the number of postgres queries made * `pace_postgres_query_failed{database}` Collects stats about the number of postgres queries failed * `pace_postgres_query_duration_seconds{database}` Collects performance metrics for each postgres query -* `pace_postgres_query_rows_total{database}` Collects stats about the number of rows returned by a postgres query * `pace_postgres_query_affected_total{database}` Collects stats about the number of rows affected by a postgres query * `pace_postgres_connection_pool_hits{database}` Collects number of times free connection was found in the pool * `pace_postgres_connection_pool_misses{database}` Collects number of times free connection was NOT found in the pool diff --git a/backend/postgres/errors.go b/backend/postgres/errors.go index b72d5ee20..7aec7250a 100644 --- a/backend/postgres/errors.go +++ b/backend/postgres/errors.go @@ -7,32 +7,34 @@ import ( "io" "net" - "github.com/go-pg/pg" + "github.com/uptrace/bun/driver/pgdriver" ) var ErrNotUnique = errors.New("not unique") func IsErrConnectionFailed(err error) bool { - // go-pg has this check internally for network errors + // bun has this check internally for network errors if errors.Is(err, io.EOF) { return true } - // go-pg has this check internally for network errors + // bun has this check internally for network errors _, ok := err.(net.Error) if ok { return true } - // go-pg has similar check for integrity violation issues, here we check network issues - pgErr, ok := err.(pg.Error) - if ok { + // bun has similar check for integrity violation issues, here we check network issues + var pgErr pgdriver.Error + + if errors.As(err, &pgErr) { code := pgErr.Field('C') // We check on error codes of Class 08 — Connection Exception. // https://www.postgresql.org/docs/10/errcodes-appendix.html - if code[0:2] == "08" { + if len(code) > 2 && code[0:2] == "08" { return true } } + return false } diff --git a/backend/postgres/errors_test.go b/backend/postgres/errors_test.go index 9df6cce7c..a5a3102a9 100644 --- a/backend/postgres/errors_test.go +++ b/backend/postgres/errors_test.go @@ -1,12 +1,12 @@ package postgres_test import ( + "context" "errors" "fmt" "io" "testing" - "github.com/go-pg/pg" "github.com/stretchr/testify/require" pbpostgres "github.com/pace/bricks/backend/postgres" @@ -19,13 +19,10 @@ func TestIsErrConnectionFailed(t *testing.T) { }) t.Run("connection failed (net.Error)", func(t *testing.T) { - db := pbpostgres.CustomConnectionPool(&pg.Options{}) // invalid connection - _, err := db.Exec("") - require.True(t, pbpostgres.IsErrConnectionFailed(err)) - }) + ctx := context.Background() - t.Run("connection failed (pg.Error)", func(t *testing.T) { - err := error(mockPGError{m: map[byte]string{'C': "08000"}}) + db := pbpostgres.NewDB(ctx, pbpostgres.WithHost("foobar")) // invalid connection + _, err := db.Exec("") require.True(t, pbpostgres.IsErrConnectionFailed(err)) }) @@ -34,11 +31,3 @@ func TestIsErrConnectionFailed(t *testing.T) { require.False(t, pbpostgres.IsErrConnectionFailed(err)) }) } - -type mockPGError struct { - m map[byte]string -} - -func (err mockPGError) Field(k byte) string { return err.m[k] } -func (err mockPGError) IntegrityViolation() bool { return false } -func (err mockPGError) Error() string { return fmt.Sprintf("%+v", err.m) } diff --git a/backend/postgres/health.go b/backend/postgres/health.go new file mode 100644 index 000000000..66c28bd6d --- /dev/null +++ b/backend/postgres/health.go @@ -0,0 +1,91 @@ +// Copyright © 2024 by PACE Telematics GmbH. All rights reserved. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/pace/bricks/maintenance/health/servicehealthcheck" + "github.com/uptrace/bun" +) + +type queryExecutor interface { + Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) +} + +// HealthCheck checks the state of a postgres connection. It must not be changed +// after it was registered as a health check. +type HealthCheck struct { + state servicehealthcheck.ConnectionState + + createTableQueryExecutor queryExecutor + deleteQueryExecutor queryExecutor + dropTableQueryExecutor queryExecutor + insertQueryExecutor queryExecutor + selectQueryExecutor queryExecutor +} + +type healthcheck struct { + bun.BaseModel + + OK bool `bun:"column:ok"` +} + +// NewHealthCheck creates a new HealthCheck instance. +func NewHealthCheck(db *bun.DB) *HealthCheck { + return &HealthCheck{ + createTableQueryExecutor: db.NewCreateTable().Model((*healthcheck)(nil)).ModelTableExpr(cfg.HealthCheckTableName).IfNotExists(), + deleteQueryExecutor: db.NewDelete().ModelTableExpr(cfg.HealthCheckTableName).Where("TRUE"), + dropTableQueryExecutor: db.NewDropTable().ModelTableExpr(cfg.HealthCheckTableName).IfExists(), + insertQueryExecutor: db.NewInsert().ModelTableExpr(cfg.HealthCheckTableName).Model(&healthcheck{OK: true}), + selectQueryExecutor: db.NewRaw("SELECT 1;"), + } +} + +// Init initializes the test table +func (h *HealthCheck) Init(ctx context.Context) error { + _, err := h.createTableQueryExecutor.Exec(ctx) + return err +} + +// HealthCheck performs the read test on the database. If enabled, it performs a +// write test as well. +func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.HealthCheckResult { + if time.Since(h.state.LastChecked()) <= cfg.HealthCheckResultTTL { + // the last result of the Health Check is still not outdated + return h.state.GetState() + } + + // Readcheck + if _, err := h.selectQueryExecutor.Exec(ctx); err != nil { + h.state.SetErrorState(err) + return h.state.GetState() + } + + // writecheck - add Data to configured Table + if _, err := h.insertQueryExecutor.Exec(ctx); err != nil { + h.state.SetErrorState(err) + return h.state.GetState() + } + + // and while we're at it, check delete as well (so as not to clutter the database + // because UPSERT is impractical here + if _, err := h.deleteQueryExecutor.Exec(ctx); err != nil { + h.state.SetErrorState(err) + return h.state.GetState() + } + + // If no error occurred set the State of this Health Check to healthy + h.state.SetHealthy() + + return h.state.GetState() +} + +// CleanUp drops the test table. +func (h *HealthCheck) CleanUp(ctx context.Context) error { + _, err := h.dropTableQueryExecutor.Exec(ctx) + + return err +} diff --git a/backend/postgres/health_postgres.go b/backend/postgres/health_postgres.go deleted file mode 100644 index 5b5dfe866..000000000 --- a/backend/postgres/health_postgres.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright © 2019 by PACE Telematics GmbH. All rights reserved. - -package postgres - -import ( - "context" - "time" - - "github.com/go-pg/pg/orm" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" -) - -// HealthCheck checks the state of a postgres connection. It must not be changed -// after it was registered as a health check. -type HealthCheck struct { - state servicehealthcheck.ConnectionState - Pool postgresQueryExecutor -} - -type postgresQueryExecutor interface { - Exec(ctx context.Context, query interface{}, params ...interface{}) (res orm.Result, err error) -} - -// Init initializes the test table -func (h *HealthCheck) Init(ctx context.Context) error { - _, errWrite := h.Pool.Exec(ctx, `CREATE TABLE IF NOT EXISTS `+cfg.HealthCheckTableName+`(ok boolean);`) - return errWrite -} - -// HealthCheck performs the read test on the database. If enabled, it performs a -// write test as well. -func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.HealthCheckResult { - if time.Since(h.state.LastChecked()) <= cfg.HealthCheckResultTTL { - // the last result of the Health Check is still not outdated - return h.state.GetState() - } - - // Readcheck - if _, err := h.Pool.Exec(ctx, `SELECT 1;`); err != nil { - h.state.SetErrorState(err) - return h.state.GetState() - } - // writecheck - add Data to configured Table - _, err := h.Pool.Exec(ctx, "INSERT INTO "+cfg.HealthCheckTableName+"(ok) VALUES (true);") - if err != nil { - h.state.SetErrorState(err) - return h.state.GetState() - } - // and while we're at it, check delete as well (so as not to clutter the database - // because UPSERT is impractical here - _, err = h.Pool.Exec(ctx, "DELETE FROM "+cfg.HealthCheckTableName+";") - if err != nil { - h.state.SetErrorState(err) - return h.state.GetState() - } - // If no error occurred set the State of this Health Check to healthy - h.state.SetHealthy() - return h.state.GetState() -} - -// CleanUp drops the test table. -func (h *HealthCheck) CleanUp(ctx context.Context) error { - _, err := h.Pool.Exec(ctx, "DROP TABLE IF EXISTS "+cfg.HealthCheckTableName) - return err -} diff --git a/backend/postgres/health_postgres_test.go b/backend/postgres/health_test.go similarity index 87% rename from backend/postgres/health_postgres_test.go rename to backend/postgres/health_test.go index 2bb978fd6..4e1de50ac 100644 --- a/backend/postgres/health_postgres_test.go +++ b/backend/postgres/health_test.go @@ -4,6 +4,7 @@ package postgres import ( "context" + "database/sql" "io" "net/http" "net/http/httptest" @@ -11,12 +12,12 @@ import ( "testing" "time" - "github.com/go-pg/pg/orm" + "github.com/stretchr/testify/require" + http2 "github.com/pace/bricks/http" "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" - "github.com/stretchr/testify/require" ) func setup() *http.Response { @@ -52,7 +53,7 @@ type testPool struct { err error } -func (t *testPool) Exec(ctx context.Context, query interface{}, params ...interface{}) (res orm.Result, err error) { +func (t *testPool) Exec(ctx context.Context, dest ...any) (sql.Result, error) { return nil, t.err } @@ -63,16 +64,25 @@ func TestHealthCheckCaching(t *testing.T) { cfg.HealthCheckResultTTL = time.Minute requiredErr := errors.New("TestHealthCheckCaching") pool := &testPool{err: requiredErr} - h := &HealthCheck{Pool: pool} + h := &HealthCheck{ + createTableQueryExecutor: pool, + deleteQueryExecutor: pool, + dropTableQueryExecutor: pool, + insertQueryExecutor: pool, + selectQueryExecutor: pool, + } + res := h.HealthCheck(ctx) // get the error for the first time require.Equal(t, servicehealthcheck.Err, res.State) require.Equal(t, "TestHealthCheckCaching", res.Msg) + res = h.HealthCheck(ctx) pool.err = nil // getting the cached error require.Equal(t, servicehealthcheck.Err, res.State) require.Equal(t, "TestHealthCheckCaching", res.Msg) + // Resetting the TTL to get a uncached result cfg.HealthCheckResultTTL = 0 res = h.HealthCheck(ctx) diff --git a/backend/postgres/hooks/logging.go b/backend/postgres/hooks/logging.go new file mode 100644 index 000000000..384b41541 --- /dev/null +++ b/backend/postgres/hooks/logging.go @@ -0,0 +1,87 @@ +// Copyright © 2024 by PACE Telematics GmbH. All rights reserved. + +package hooks + +import ( + "context" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/uptrace/bun" + + "github.com/pace/bricks/maintenance/log" +) + +type queryMode int + +const ( + readMode queryMode = iota + writeMode queryMode = iota +) + +type LoggingHook struct { + logReadQueries bool + logWriteQueries bool +} + +func NewLoggingHook(logRead bool, logWrite bool) *LoggingHook { + return &LoggingHook{ + logReadQueries: logRead, + logWriteQueries: logWrite, + } +} + +func (h *LoggingHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context { + return ctx +} + +func (h *LoggingHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) { + // we can only and should only perfom the following check if we have the information availaible + mode := determineQueryMode(event.Query) + + if mode == readMode && !h.logReadQueries { + return + } + + if mode == writeMode && !h.logWriteQueries { + return + } + + dur := float64(time.Since(event.StartTime)) / float64(time.Millisecond) + + // check if log context is given + var logger *zerolog.Logger + if ctx != nil { + logger = log.Ctx(ctx) + } else { + logger = log.Logger() + } + + // add general info + logEvent := logger.Debug(). + Float64("duration", dur). + Str("sentry:category", "postgres") + + // add error or result set info + if event.Err != nil { + logEvent = logEvent.Err(event.Err) + } else if event.Result != nil { + rowsAffected, err := event.Result.RowsAffected() + if err == nil { + logEvent = logEvent.Int64("affected", rowsAffected) + } + } + + logEvent.Msg(event.Query) +} + +// determineQueryMode is a poorman's attempt at checking whether the query is a read or write to the database. +// Feel free to improve. +func determineQueryMode(qry string) queryMode { + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(qry)), "select") { + return readMode + } + + return writeMode +} diff --git a/backend/postgres/hooks/metrics.go b/backend/postgres/hooks/metrics.go new file mode 100644 index 000000000..aa0cc50ba --- /dev/null +++ b/backend/postgres/hooks/metrics.go @@ -0,0 +1,83 @@ +// Copyright © 2024 by PACE Telematics GmbH. All rights reserved. + +package hooks + +import ( + "context" + "math" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/uptrace/bun" +) + +var ( + MetricQueryTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "pace_postgres_query_total", + Help: "Collects stats about the number of postgres queries made", + }, + []string{"database"}, + ) + MetricQueryFailed = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "pace_postgres_query_failed", + Help: "Collects stats about the number of postgres queries failed", + }, + []string{"database"}, + ) + MetricQueryDurationSeconds = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "pace_postgres_query_duration_seconds", + Help: "Collect performance metrics for each postgres query", + Buckets: []float64{.1, .25, .5, 1, 2.5, 5, 10, 60}, + }, + []string{"database"}, + ) + MetricQueryAffectedTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "pace_postgres_query_affected_total", + Help: "Collects stats about the number of rows affected by a postgres query", + }, + []string{"database"}, + ) +) + +type MetricsHook struct { + addr string + database string +} + +func NewMetricsHook(addr string, database string) *MetricsHook { + return &MetricsHook{ + addr: addr, + database: database, + } +} + +func (h *MetricsHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context { + return ctx +} + +func (h *MetricsHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) { + dur := float64(time.Since(event.StartTime)) / float64(time.Millisecond) + + labels := prometheus.Labels{ + "database": h.addr + "/" + h.database, + } + + MetricQueryTotal.With(labels).Inc() + + if event.Err != nil { + MetricQueryFailed.With(labels).Inc() + } else if event.Result != nil { + r := event.Result + + rowsAffected, err := r.RowsAffected() + if err == nil { + MetricQueryAffectedTotal.With(labels).Add(math.Max(0, float64(rowsAffected))) + } + } + + MetricQueryDurationSeconds.With(labels).Observe(dur) +} diff --git a/backend/postgres/hooks/tracing.go b/backend/postgres/hooks/tracing.go new file mode 100644 index 000000000..7b5549c0c --- /dev/null +++ b/backend/postgres/hooks/tracing.go @@ -0,0 +1,60 @@ +// Copyright © 2024 by PACE Telematics GmbH. All rights reserved. + +package hooks + +import ( + "context" + "regexp" + "strings" + + "github.com/opentracing/opentracing-go" + olog "github.com/opentracing/opentracing-go/log" + "github.com/uptrace/bun" +) + +var ( + reQueryType = regexp.MustCompile(`(\s)`) + reQueryTypeCleanup = regexp.MustCompile(`(?m)(\s+|\n)`) +) + +type TracingHook struct{} + +func (h *TracingHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context { + return ctx +} + +func (h *TracingHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) { + span, _ := opentracing.StartSpanFromContext(ctx, "sql: "+getQueryType(event.Query), + opentracing.StartTime(event.StartTime)) + defer span.Finish() + + span.SetTag("db.system", "postgres") + + fields := []olog.Field{ + olog.String("query", event.Query), + } + + // add error or result set info + if event.Err != nil { + fields = append(fields, olog.Error(event.Err)) + } else if event.Result != nil { + rowsAffected, err := event.Result.RowsAffected() + if err == nil { + fields = append(fields, olog.Int64("affected", rowsAffected)) + } + } + + span.LogFields(fields...) +} + +func getQueryType(s string) string { + s = reQueryTypeCleanup.ReplaceAllString(s, " ") + s = strings.TrimSpace(s) + + p := reQueryType.FindStringIndex(s) + if len(p) > 0 { + return strings.ToUpper(s[:p[0]]) + } + + return strings.ToUpper(s) +} diff --git a/backend/postgres/metrics.go b/backend/postgres/metrics.go deleted file mode 100644 index 69766aaf3..000000000 --- a/backend/postgres/metrics.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright © 2019 by PACE Telematics GmbH. All rights reserved. - -package postgres - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/go-pg/pg" - "github.com/prometheus/client_golang/prometheus" -) - -// ConnectionPoolMetrics is the metrics collector for postgres connection pools -// (pace_postgres_connection_pool_*). It is capable of running an observer that -// periodically gathers those stats. -type ConnectionPoolMetrics struct { - poolMetrics map[string]struct{} - poolMetricsMx sync.Mutex - - hits *prometheus.CounterVec - misses *prometheus.CounterVec - timeouts *prometheus.CounterVec - totalConns *prometheus.GaugeVec - idleConns *prometheus.GaugeVec - staleConns *prometheus.GaugeVec -} - -// NewConnectionPoolMetrics returns a new metrics collector for postgres -// connection pools. -func NewConnectionPoolMetrics() *ConnectionPoolMetrics { - m := ConnectionPoolMetrics{ - poolMetrics: map[string]struct{}{}, - hits: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_connection_pool_hits", - Help: "Collects number of times free connection was found in the pool", - }, - []string{"database", "pool"}, - ), - misses: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_connection_pool_misses", - Help: "Collects number of times free connection was NOT found in the pool", - }, - []string{"database", "pool"}, - ), - timeouts: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_connection_pool_timeouts", - Help: "Collects number of times a wait timeout occurred", - }, - []string{"database", "pool"}, - ), - totalConns: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "pace_postgres_connection_pool_total_conns", - Help: "Collects number of total connections in the pool", - }, - []string{"database", "pool"}, - ), - idleConns: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "pace_postgres_connection_pool_idle_conns", - Help: "Collects number of idle connections in the pool", - }, - []string{"database", "pool"}, - ), - staleConns: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "pace_postgres_connection_pool_stale_conns", - Help: "Collects number of stale connections removed from the pool", - }, - []string{"database", "pool"}, - ), - } - return &m -} - -// The metrics implement the prometheus collector methods. This allows to -// register them directly with a registry. -var _ prometheus.Collector = (*ConnectionPoolMetrics)(nil) - -// Describe descibes all the embedded prometheus metrics. -func (m *ConnectionPoolMetrics) Describe(ch chan<- *prometheus.Desc) { - m.hits.Describe(ch) - m.misses.Describe(ch) - m.timeouts.Describe(ch) - m.totalConns.Describe(ch) - m.idleConns.Describe(ch) - m.staleConns.Describe(ch) -} - -// Collect collects all the embedded prometheus metrics. -func (m *ConnectionPoolMetrics) Collect(ch chan<- prometheus.Metric) { - m.hits.Collect(ch) - m.misses.Collect(ch) - m.timeouts.Collect(ch) - m.totalConns.Collect(ch) - m.idleConns.Collect(ch) - m.staleConns.Collect(ch) -} - -// ObserveRegularly starts observing the given postgres pool. The provided pool -// name must be unique as it distinguishes multiple pools. The pool name is -// exposed as the "pool" label in the metrics. The metrics are collected once -// per minute for as long as the passed context is valid. -func (m *ConnectionPoolMetrics) ObserveRegularly(ctx context.Context, db *pg.DB, poolName string) error { - trigger := make(chan chan<- struct{}) - if err := m.ObserveWhenTriggered(trigger, db, poolName); err != nil { - return err - } - - // Trigger once a minute until context is cancelled. In the following - // goroutine we create a ticker that writes to a channel every minute. If - // this happens we write to the trigger channel and that will trigger - // observing the metrics. Both channel operations are blocking which is why - // we have to check the context two times. So that the goroutine doesn't - // stick around forever which would prevent the garbage collection from - // cleaning up the related resources. - go func() { - ticker := time.NewTicker(time.Minute) - defer close(trigger) - for { - select { - case <-ticker.C: - select { - // The trigger channel allows passing another channel if we - // wanted to get notified when observing the metrics is done. - // But we don't, so we just pass nil. - case trigger <- nil: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } - } - }() - - return nil -} - -// ObserveWhenTriggered starts observing the given postgres pool. The pool name -// behaves as decribed for the ObserveRegularly method. The metrics are observed -// for every emitted value from the trigger channel. The trigger channel allows -// passing a response channel that will be closed once the metrics were -// collected. It is also possible to pass nil. You should close the trigger -// channel when done to allow cleaning up. -func (m *ConnectionPoolMetrics) ObserveWhenTriggered(trigger <-chan chan<- struct{}, db *pg.DB, poolName string) error { - // check that pool name is unique - m.poolMetricsMx.Lock() - defer m.poolMetricsMx.Unlock() - if _, ok := m.poolMetrics[poolName]; ok { - return fmt.Errorf("invalid pool name: %q: %w", poolName, ErrNotUnique) - } - m.poolMetrics[poolName] = struct{}{} - - // start goroutine - go m.gatherConnectionPoolMetrics(trigger, db, poolName) - return nil -} - -func (m *ConnectionPoolMetrics) gatherConnectionPoolMetrics(trigger <-chan chan<- struct{}, db *pg.DB, poolName string) { - // prepare labels for all stats - opts := db.Options() - labels := prometheus.Labels{ - "database": opts.Addr + "/" + opts.Database, - "pool": poolName, - } - - // keep previous stats for the counters - var prevStats pg.PoolStats - - // collect all the pool stats whenever triggered - for done := range trigger { - stats := db.PoolStats() - // counters - m.hits.With(labels).Add(float64(stats.Hits - prevStats.Hits)) - m.misses.With(labels).Add(float64(stats.Misses - prevStats.Misses)) - m.timeouts.With(labels).Add(float64(stats.Timeouts - prevStats.Timeouts)) - // gauges - m.totalConns.With(labels).Set(float64(stats.TotalConns)) - m.idleConns.With(labels).Set(float64(stats.IdleConns)) - m.staleConns.With(labels).Set(float64(stats.StaleConns)) - // inform caller that we are done - if done != nil { - close(done) - } - prevStats = *stats - } -} diff --git a/backend/postgres/metrics_test.go b/backend/postgres/metrics_test.go deleted file mode 100644 index 98ecd8ffa..000000000 --- a/backend/postgres/metrics_test.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright © 2019 by PACE Telematics GmbH. All rights reserved. - -package postgres_test - -import ( - "context" - "errors" - "net/http/httptest" - "testing" - "time" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - . "github.com/pace/bricks/backend/postgres" -) - -func ExampleConnectionPoolMetrics() { - myDB := ConnectionPool() - - // collect stats about my db every minute - metrics := NewConnectionPoolMetrics() - if err := metrics.ObserveRegularly(context.Background(), myDB, "my_db"); err != nil { - panic(err) - } - prometheus.MustRegister(metrics) -} - -func TestIntegrationConnectionPoolMetrics(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - // prepare connection pool with metrics - metricsRegistry := prometheus.NewRegistry() - metrics := NewConnectionPoolMetrics() - metricsRegistry.MustRegister(metrics) - db := ConnectionPool() - trigger := make(chan chan<- struct{}) - err := metrics.ObserveWhenTriggered(trigger, db, "test") - require.NoError(t, err) - // collect some metrics - if _, err := db.Exec(`SELECT 1;`); err != nil { - t.Fatalf("could not query postgres database: %s", err) - } - whenDone := make(chan struct{}) - select { - case trigger <- whenDone: - case <-time.After(time.Second): - t.Fatal("did not start collecting metrics after 1s") - } - select { - case <-whenDone: - case <-time.After(time.Second): - t.Fatal("metrics were not collected after 1s") - } - // query metrics - resp := httptest.NewRecorder() - handler := promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) - handler.ServeHTTP(resp, httptest.NewRequest("GET", "/metrics", nil)) - body := resp.Body.String() - assert.Regexp(t, `pace_postgres_connection_pool_hits.*?\Wpool="test"\W`, body) - assert.Regexp(t, `pace_postgres_connection_pool_misses.*?\Wpool="test"\W`, body) - assert.Regexp(t, `pace_postgres_connection_pool_timeouts.*?\Wpool="test"\W`, body) - assert.Regexp(t, `pace_postgres_connection_pool_total_conns.*?\Wpool="test"\W`, body) - assert.Regexp(t, `pace_postgres_connection_pool_idle_conns.*?\Wpool="test"\W`, body) - assert.Regexp(t, `pace_postgres_connection_pool_stale_conns.*?\Wpool="test"\W`, body) -} - -// Tests that the NewConnectionPoolMetrics don't allow registering pools using -// the same pool name. -func TestIntegrationConnectionPoolMetrics_duplicatePoolName(t *testing.T) { - metrics := NewConnectionPoolMetrics() - // register first with name "test" - err := metrics.ObserveRegularly(context.Background(), ConnectionPool(), "test") - require.NoError(t, err) - // registering second with name "test" fails - err = metrics.ObserveRegularly(context.Background(), ConnectionPool(), "test") - assert.True(t, errors.Is(err, ErrNotUnique)) -} diff --git a/backend/postgres/options.go b/backend/postgres/options.go index cca80268b..14389c067 100644 --- a/backend/postgres/options.go +++ b/backend/postgres/options.go @@ -1,10 +1,8 @@ -// Copyright © 2022 by PACE Telematics GmbH. All rights reserved. +// Copyright © 2024 by PACE Telematics GmbH. All rights reserved. package postgres -import ( - "time" -) +import "time" type ConfigOption func(cfg *Config) @@ -58,36 +56,6 @@ func WithApplicationName(applicationName string) ConfigOption { } } -// WithMaxRetries - Maximum number of retries before giving up. -func WithMaxRetries(maxRetries int) ConfigOption { - return func(cfg *Config) { - cfg.MaxRetries = maxRetries - } -} - -// WithRetryStatementTimeout - Whether to retry queries cancelled because of statement_timeout. -func WithRetryStatementTimeout(retryStatementTimeout bool) ConfigOption { - return func(cfg *Config) { - cfg.RetryStatementTimeout = retryStatementTimeout - } -} - -// WithMinRetryBackoff - Minimum backoff between each retry. -// -1 disables backoff. -func WithMinRetryBackoff(minRetryBackoff time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.MinRetryBackoff = minRetryBackoff - } -} - -// WithMaxRetryBackoff - Maximum backoff between each retry. -// -1 disables backoff. -func WithMaxRetryBackoff(maxRetryBackoff time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.MaxRetryBackoff = maxRetryBackoff - } -} - // WithDialTimeout - Dial timeout for establishing new connections. func WithDialTimeout(dialTimeout time.Duration) ConfigOption { return func(cfg *Config) { @@ -110,67 +78,3 @@ func WithWriteTimeout(writeTimeout time.Duration) ConfigOption { cfg.WriteTimeout = writeTimeout } } - -// WithPoolSize - Maximum number of socket connections. -func WithPoolSize(poolSize int) ConfigOption { - return func(cfg *Config) { - cfg.PoolSize = poolSize - } -} - -// WithMinIdleConns - Minimum number of idle connections which is useful when establishing -// new connection is slow. -func WithMinIdleConns(minIdleConns int) ConfigOption { - return func(cfg *Config) { - cfg.MinIdleConns = minIdleConns - } -} - -// WithMaxConnAge - Connection age at which client retires (closes) the connection. -// It is useful with proxies like PgBouncer and HAProxy. -func WithMaxConnAge(maxConnAge time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.MaxConnAge = maxConnAge - } -} - -// WithPoolTimeout - Time for which client waits for free connection if all -// connections are busy before returning an error. -func WithPoolTimeout(poolTimeout time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.PoolTimeout = poolTimeout - } -} - -// WithIdleTimeout - Amount of time after which client closes idle connections. -// Should be less than server's timeout. -// -1 disables idle timeout check. -func WithIdleTimeout(idleTimeout time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.IdleTimeout = idleTimeout - } -} - -// WithIdleCheckFrequency - Frequency of idle checks made by idle connection's reaper. -// -1 disables idle connection's reaper, -// but idle connections are still discarded by the client -// if IdleTimeout is set. -func WithIdleCheckFrequency(idleCheckFrequency time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.IdleCheckFrequency = idleCheckFrequency - } -} - -// WithHealthCheckTableName - Name of the Table that is created to try if database is writeable -func WithHealthCheckTableName(healthCheckTableName string) ConfigOption { - return func(cfg *Config) { - cfg.HealthCheckTableName = healthCheckTableName - } -} - -// WithHealthCheckResultTTL - Amount of time to cache the last health check result -func WithHealthCheckResultTTL(healthCheckResultTTL time.Duration) ConfigOption { - return func(cfg *Config) { - cfg.HealthCheckResultTTL = healthCheckResultTTL - } -} diff --git a/backend/postgres/options_test.go b/backend/postgres/options_test.go deleted file mode 100644 index 0ea808de4..000000000 --- a/backend/postgres/options_test.go +++ /dev/null @@ -1,203 +0,0 @@ -package postgres - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestWithApplicationName(t *testing.T) { - param := "ApplicationName" - var conf Config - f := WithApplicationName(param) - f(&conf) - require.Equal(t, conf.ApplicationName, param) -} - -func TestWithDatabase(t *testing.T) { - param := "Database" - var conf Config - f := WithDatabase(param) - f(&conf) - require.Equal(t, conf.Database, param) -} - -func TestWithDialTimeout(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithDialTimeout(param) - f(&conf) - require.Equal(t, conf.DialTimeout, param) -} - -func TestWithHealthCheckResultTTL(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithHealthCheckResultTTL(param) - f(&conf) - require.Equal(t, conf.HealthCheckResultTTL, param) -} - -func TestWithHealthCheckTableName(t *testing.T) { - param := "HealthCheckTableName" - var conf Config - f := WithHealthCheckTableName(param) - f(&conf) - require.Equal(t, conf.HealthCheckTableName, param) -} - -func TestWithHost(t *testing.T) { - param := "Host" - var conf Config - f := WithHost(param) - f(&conf) - require.Equal(t, conf.Host, param) -} - -func TestWithIdleCheckFrequency(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithIdleCheckFrequency(param) - f(&conf) - require.Equal(t, conf.IdleCheckFrequency, param) -} - -func TestWithIdleTimeout(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithIdleTimeout(param) - f(&conf) - require.Equal(t, conf.IdleTimeout, param) -} - -func TestWithMaxConnAge(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithMaxConnAge(param) - f(&conf) - require.Equal(t, conf.MaxConnAge, param) -} - -func TestWithMaxRetries(t *testing.T) { - param := 42 - var conf Config - f := WithMaxRetries(param) - f(&conf) - require.Equal(t, conf.MaxRetries, param) -} - -func TestWithMaxRetryBackoff(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithMaxRetryBackoff(param) - f(&conf) - require.Equal(t, conf.MaxRetryBackoff, param) -} - -func TestWithMinIdleConns(t *testing.T) { - param := 42 - var conf Config - f := WithMinIdleConns(param) - f(&conf) - require.Equal(t, conf.MinIdleConns, param) -} - -func TestWithMinRetryBackoff(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithMinRetryBackoff(param) - f(&conf) - require.Equal(t, conf.MinRetryBackoff, param) -} - -func TestWithPassword(t *testing.T) { - param := "Password" - var conf Config - f := WithPassword(param) - f(&conf) - require.Equal(t, conf.Password, param) -} - -func TestWithPoolSize(t *testing.T) { - param := 42 - var conf Config - f := WithPoolSize(param) - f(&conf) - require.Equal(t, conf.PoolSize, param) -} - -func TestWithPoolTimeout(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithPoolTimeout(param) - f(&conf) - require.Equal(t, conf.PoolTimeout, param) -} - -func TestWithPort(t *testing.T) { - param := 42 - var conf Config - f := WithPort(param) - f(&conf) - require.Equal(t, conf.Port, param) -} - -func TestWithReadTimeout(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithReadTimeout(param) - f(&conf) - require.Equal(t, conf.ReadTimeout, param) -} - -func TestWithRetryStatementTimeout(t *testing.T) { - param := true - var conf Config - f := WithRetryStatementTimeout(param) - f(&conf) - require.Equal(t, conf.RetryStatementTimeout, param) -} - -func TestWithUser(t *testing.T) { - param := "User" - var conf Config - f := WithUser(param) - f(&conf) - require.Equal(t, conf.User, param) -} - -func TestWithWriteTimeout(t *testing.T) { - param := 5 * time.Second - var conf Config - f := WithWriteTimeout(param) - f(&conf) - require.Equal(t, conf.WriteTimeout, param) -} - -func TestWithLogReadWriteOnly(t *testing.T) { - cases := [][]bool{ - { - true, true, - }, - { - false, true, - }, - { - true, false, - }, - { - false, false, - }, - } - for _, tc := range cases { - read := tc[0] - write := tc[1] - var conf Config - f := WithQueryLogging(read, write) - f(&conf) - assert.Equal(t, conf.LogRead, read) - assert.Equal(t, conf.LogWrite, write) - } -} diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index ec0b27af0..ea3d24b09 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -1,28 +1,24 @@ -// Copyright © 2018 by PACE Telematics GmbH. All rights reserved. +// Copyright © 2024 by PACE Telematics GmbH. All rights reserved. -// Package postgres helps creating PostgreSQL connection pools package postgres import ( "context" - "fmt" - "math" + "database/sql" + "net" "os" "path/filepath" - "regexp" - "strings" - "sync" + "strconv" "time" - "github.com/opentracing/opentracing-go" - olog "github.com/opentracing/opentracing-go/log" - "github.com/rs/zerolog" - "github.com/caarlos0/env/v10" - "github.com/go-pg/pg" + "github.com/pace/bricks/backend/postgres/hooks" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/prometheus/client_golang/prometheus" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" + "github.com/uptrace/bun/driver/pgdriver" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" ) @@ -36,44 +32,8 @@ type Config struct { // ApplicationName is the application name. Used in logs on Pg side. // Only availaible from pg-9.0. ApplicationName string `env:"POSTGRES_APPLICATION_NAME" envDefault:"-"` - // Maximum number of retries before giving up. - MaxRetries int `env:"POSTGRES_MAX_RETRIES" envDefault:"5"` - // Whether to retry queries cancelled because of statement_timeout. - RetryStatementTimeout bool `env:"POSTGRES_RETRY_STATEMENT_TIMEOUT" envDefault:"false"` - // Minimum backoff between each retry. - // -1 disables backoff. - MinRetryBackoff time.Duration `env:"POSTGRES_MIN_RETRY_BACKOFF" envDefault:"250ms"` - // Maximum backoff between each retry. - // -1 disables backoff. - MaxRetryBackoff time.Duration `env:"POSTGRES_MAX_RETRY_BACKOFF" envDefault:"4s"` // Dial timeout for establishing new connections. DialTimeout time.Duration `env:"POSTGRES_DIAL_TIMEOUT" envDefault:"5s"` - // Timeout for socket reads. If reached, commands will fail - // with a timeout instead of blocking. - ReadTimeout time.Duration `env:"POSTGRES_READ_TIMEOUT" envDefault:"30s"` - // Timeout for socket writes. If reached, commands will fail - // with a timeout instead of blocking. - WriteTimeout time.Duration `env:"POSTGRES_WRITE_TIMEOUT" envDefault:"30s"` - // Maximum number of socket connections. - PoolSize int `env:"POSTGRES_POOL_SIZE" envDefault:"100"` - // Minimum number of idle connections which is useful when establishing - // new connection is slow. - MinIdleConns int `env:"POSTGRES_MIN_IDLE_CONNECTIONS" envDefault:"10"` - // Connection age at which client retires (closes) the connection. - // It is useful with proxies like PgBouncer and HAProxy. - MaxConnAge time.Duration `env:"POSTGRES_MAX_CONN_AGE" envDefault:"30m"` - // Time for which client waits for free connection if all - // connections are busy before returning an error. - PoolTimeout time.Duration `env:"POSTGRES_POOL_TIMEOUT" envDefault:"31s"` - // Amount of time after which client closes idle connections. - // Should be less than server's timeout. - // -1 disables idle timeout check. - IdleTimeout time.Duration `env:"POSTGRES_IDLE_TIMEOUT" envDefault:"5m"` - // Frequency of idle checks made by idle connections reaper. - // -1 disables idle connections reaper, - // but idle connections are still discarded by the client - // if IdleTimeout is set. - IdleCheckFrequency time.Duration `env:"POSTGRES_IDLE_CHECK_FREQUENCY" envDefault:"1m"` // Name of the Table that is created to try if database is writeable HealthCheckTableName string `env:"POSTGRES_HEALTH_CHECK_TABLE_NAME" envDefault:"healthcheck"` // Amount of time to cache the last health check result @@ -82,57 +42,22 @@ type Config struct { LogWrite bool `env:"POSTGRES_LOG_WRITES" envDefault:"true"` // Indicator whether read (select) queries should be logged LogRead bool `env:"POSTGRES_LOG_READS" envDefault:"false"` + // Timeout for socket reads. If reached, commands will fail + // with a timeout instead of blocking. + ReadTimeout time.Duration `env:"POSTGRES_READ_TIMEOUT" envDefault:"30s"` + // Timeout for socket writes. If reached, commands will fail + // with a timeout instead of blocking. + WriteTimeout time.Duration `env:"POSTGRES_WRITE_TIMEOUT" envDefault:"30s"` } -var ( - metricQueryTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_query_total", - Help: "Collects stats about the number of postgres queries made", - }, - []string{"database"}, - ) - metricQueryFailed = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_query_failed", - Help: "Collects stats about the number of postgres queries failed", - }, - []string{"database"}, - ) - metricQueryDurationSeconds = prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Name: "pace_postgres_query_duration_seconds", - Help: "Collect performance metrics for each postgres query", - Buckets: []float64{.1, .25, .5, 1, 2.5, 5, 10, 60}, - }, - []string{"database"}, - ) - metricQueryRowsTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_query_rows_total", - Help: "Collects stats about the number of rows returned by a postgres query", - }, - []string{"database"}, - ) - metricQueryAffectedTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "pace_postgres_query_affected_total", - Help: "Collects stats about the number of rows affected by a postgres query", - }, - []string{"database"}, - ) -) - var cfg Config func init() { - prometheus.MustRegister(metricQueryTotal) - prometheus.MustRegister(metricQueryFailed) - prometheus.MustRegister(metricQueryDurationSeconds) - prometheus.MustRegister(metricQueryRowsTotal) - prometheus.MustRegister(metricQueryAffectedTotal) + prometheus.MustRegister(hooks.MetricQueryTotal) + prometheus.MustRegister(hooks.MetricQueryFailed) + prometheus.MustRegister(hooks.MetricQueryDurationSeconds) + prometheus.MustRegister(hooks.MetricQueryAffectedTotal) - // parse log Config err := env.Parse(&cfg) if err != nil { log.Fatalf("Failed to parse postgres environment: %v", err) @@ -154,229 +79,44 @@ func init() { } } - servicehealthcheck.RegisterHealthCheck("postgresdefault", &HealthCheck{ - Pool: &pgPoolAdapter{db: DefaultConnectionPool()}, - }) + servicehealthcheck.RegisterHealthCheck("postgresdefault", NewHealthCheck(NewDB(context.Background()))) } -var ( - defaultPool *pg.DB - defaultPoolOnce sync.Once -) +func NewDB(ctx context.Context, options ...ConfigOption) *bun.DB { + for _, opt := range options { + opt(&cfg) + } + + connector := pgdriver.NewConnector( + pgdriver.WithAddr(net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port))), + pgdriver.WithApplicationName(cfg.ApplicationName), + pgdriver.WithDatabase(cfg.Database), + pgdriver.WithDialTimeout(cfg.DialTimeout), + pgdriver.WithPassword(cfg.Password), + pgdriver.WithReadTimeout(cfg.ReadTimeout), + pgdriver.WithUser(cfg.User), + pgdriver.WithWriteTimeout(cfg.WriteTimeout), + pgdriver.WithInsecure(true), + ) -// DefaultConnectionPool returns a the default database connection pool that is -// configured using the POSTGRES_* env vars and instrumented with tracing, -// logging and metrics. -func DefaultConnectionPool() *pg.DB { - var err error - defaultPoolOnce.Do(func() { - if defaultPool == nil { - defaultPool = ConnectionPool() - // add metrics - metrics := NewConnectionPoolMetrics() - prometheus.MustRegister(metrics) - err = metrics.ObserveRegularly(context.Background(), defaultPool, "default") - } - }) - if err != nil { - panic(err) - } - return defaultPool -} + sqldb := sql.OpenDB(connector) + db := bun.NewDB(sqldb, pgdialect.New()) -// ConnectionPool returns a new database connection pool -// that is already configured with the correct credentials and -// instrumented with tracing and logging -// Used Config is taken from the env and it's default values. These -// values can be overwritten by the use of ConfigOption. -func ConnectionPool(opts ...ConfigOption) *pg.DB { - // apply functional options if given to overwrite the default config / env config - for _, f := range opts { - f(&cfg) - } + log.Ctx(ctx).Info().Str("addr", connector.Config().Addr). + Str("user", connector.Config().User). + Str("database", connector.Config().Database). + Str("as", connector.Config().AppName). + Msg("PostgreSQL connection pool created") - return CustomConnectionPool(&pg.Options{ - Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), - User: cfg.User, - Password: cfg.Password, - Database: cfg.Database, - ApplicationName: cfg.ApplicationName, - MaxRetries: cfg.MaxRetries, - RetryStatementTimeout: cfg.RetryStatementTimeout, - MinRetryBackoff: cfg.MinRetryBackoff, - MaxRetryBackoff: cfg.MaxRetryBackoff, - DialTimeout: cfg.DialTimeout, - ReadTimeout: cfg.ReadTimeout, - WriteTimeout: cfg.WriteTimeout, - PoolSize: cfg.PoolSize, - MinIdleConns: cfg.MinIdleConns, - MaxConnAge: cfg.MaxConnAge, - PoolTimeout: cfg.PoolTimeout, - IdleTimeout: cfg.IdleTimeout, - IdleCheckFrequency: cfg.IdleCheckFrequency, - }) -} + // Add hooks + db.AddQueryHook(&hooks.TracingHook{}) + db.AddQueryHook(hooks.NewMetricsHook(cfg.Host, cfg.Database)) -// CustomConnectionPool returns a new database connection pool -// that is already configured with the correct credentials and -// instrumented with tracing and logging using the passed options -// -// Fot a health check for this connection a PgHealthCheck needs to -// be registered: -// -// servicehealthcheck.RegisterHealthCheck(...) -func CustomConnectionPool(opts *pg.Options) *pg.DB { - log.Logger().Info().Str("addr", opts.Addr). - Str("user", opts.User). - Str("database", opts.Database). - Str("as", opts.ApplicationName). - Msg("PostgreSQL connection pool created") - db := pg.Connect(opts) if cfg.LogWrite || cfg.LogRead { - db.OnQueryProcessed(queryLogger) + db.AddQueryHook(hooks.NewLoggingHook(cfg.LogRead, cfg.LogWrite)) } else { - log.Logger().Warn().Msg("Connection pool has logging queries disabled completely") + log.Ctx(ctx).Warn().Msg("Connection pool has logging queries disabled completely") } - db.OnQueryProcessed(openTracingAdapter) - db.OnQueryProcessed(func(event *pg.QueryProcessedEvent) { - metricsAdapter(event, opts) - }) - return db } - -type queryMode int - -const ( - readMode queryMode = iota - writeMode queryMode = iota -) - -// determineQueryMode is a poorman's attempt at checking whether the query is a read or write to the database. -// Feel free to improve. -func determineQueryMode(qry string) queryMode { - if strings.HasPrefix(strings.ToLower(strings.TrimSpace(qry)), "select") { - return readMode - } - return writeMode -} - -func queryLogger(event *pg.QueryProcessedEvent) { - q, qe := event.UnformattedQuery() - if qe == nil { - if !(cfg.LogRead || cfg.LogWrite) { - return - } - // we can only and should only perfom the following check if we have the information availaible - mode := determineQueryMode(q) - if mode == readMode && !cfg.LogRead { - return - } - if mode == writeMode && !cfg.LogWrite { - return - } - - } - ctx := event.DB.Context() - dur := float64(time.Since(event.StartTime)) / float64(time.Millisecond) - - // check if log context is given - var logger *zerolog.Logger - if ctx != nil { - logger = log.Ctx(ctx) - } else { - logger = log.Logger() - } - - // add general info - le := logger.Debug(). - Str("file", event.File). - Int("line", event.Line). - Str("func", event.Func). - Int("attempt", event.Attempt). - Float64("duration", dur). - Str("sentry:category", "postgres") - - // add error or result set info - if event.Error != nil { - le = le.Err(event.Error) - } else { - le = le.Int("affected", event.Result.RowsAffected()). - Int("rows", event.Result.RowsReturned()) - } - - if qe != nil { - // this is only a display issue not a "real" issue - le.Msgf("%v", qe) - } - le.Msg(q) -} - -var ( - reQueryType = regexp.MustCompile(`(\s)`) - reQueryTypeCleanup = regexp.MustCompile(`(?m)(\s+|\n)`) -) - -func getQueryType(s string) string { - s = reQueryTypeCleanup.ReplaceAllString(s, " ") - s = strings.TrimSpace(s) - - p := reQueryType.FindStringIndex(s) - if len(p) > 0 { - return strings.ToUpper(s[:p[0]]) - } - return strings.ToUpper(s) -} - -func openTracingAdapter(event *pg.QueryProcessedEvent) { - // start span with general info - q, qe := event.UnformattedQuery() - if qe != nil { - // this is only a display issue not a "real" issue - q = qe.Error() - } - - span, _ := opentracing.StartSpanFromContext(event.DB.Context(), "sql: "+getQueryType(q), - opentracing.StartTime(event.StartTime)) - - span.SetTag("db.system", "postgres") - - fields := []olog.Field{ - olog.String("file", event.File), - olog.Int("line", event.Line), - olog.String("func", event.Func), - olog.Int("attempt", event.Attempt), - olog.String("query", q), - } - - // add error or result set info - if event.Error != nil { - fields = append(fields, olog.Error(event.Error)) - } else { - fields = append(fields, - olog.Int("affected", event.Result.RowsAffected()), - olog.Int("rows", event.Result.RowsReturned())) - } - - span.LogFields(fields...) - span.Finish() -} - -func metricsAdapter(event *pg.QueryProcessedEvent, opts *pg.Options) { - dur := float64(time.Since(event.StartTime)) / float64(time.Millisecond) - labels := prometheus.Labels{ - "database": opts.Addr + "/" + opts.Database, - } - - metricQueryTotal.With(labels).Inc() - - if event.Error != nil { - metricQueryFailed.With(labels).Inc() - } else { - r := event.Result - metricQueryRowsTotal.With(labels).Add(float64(r.RowsReturned())) - metricQueryAffectedTotal.With(labels).Add(math.Max(0, float64(r.RowsAffected()))) - } - metricQueryDurationSeconds.With(labels).Observe(dur) -} diff --git a/backend/postgres/postgres_test.go b/backend/postgres/postgres_test.go deleted file mode 100644 index 5206720e9..000000000 --- a/backend/postgres/postgres_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright © 2018 by PACE Telematics GmbH. All rights reserved. - -package postgres - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestIntegrationConnectionPool(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - db := ConnectionPool() - var result struct { - Calc int - } - _, err := db.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) //nolint:errcheck - if err != nil { - t.Errorf("got %v", err) - } - - // Note: This test can't actually test the logging correctly - // but the code will be accessed -} - -func TestIntegrationConnectionPoolNoLogging(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - db := ConnectionPool(WithQueryLogging(false, false)) - var result struct { - Calc int - } - _, err := db.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) //nolint:errcheck - if err != nil { - t.Errorf("got %v", err) - } - - // Note: This test can't actually test the logging correctly - // but the code will be accessed -} - -func TestGetQueryType(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - - testQuery1 := `SELECT * FROM example` - require.Equal(t, "SELECT", getQueryType(testQuery1)) - - testQuery2 := ` - SELECT - * FROM example` - require.Equal(t, "SELECT", getQueryType(testQuery2)) - - testQuery3 := `UPDATE example SET foo = 1` - require.Equal(t, "UPDATE", getQueryType(testQuery3)) - - testQuery4 := `COPY film_locations FROM '/tmp/foo.csv' HEADER CSV DELIMITER ',';` - require.Equal(t, "COPY", getQueryType(testQuery4)) -} diff --git a/backend/postgres/query_ctx.go b/backend/postgres/query_ctx.go deleted file mode 100644 index 58515082d..000000000 --- a/backend/postgres/query_ctx.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright © 2020 by PACE Telematics GmbH. All rights reserved. - -package postgres - -import ( - "context" - - "github.com/go-pg/pg" - "github.com/go-pg/pg/orm" -) - -type pgPoolAdapter struct { - db *pg.DB -} - -func (a *pgPoolAdapter) Exec(ctx context.Context, query interface{}, params ...interface{}) (res orm.Result, err error) { - db := a.db.WithContext(ctx) - return db.Exec(query, params...) -} diff --git a/http/jsonapi/runtime/standard_params.go b/http/jsonapi/runtime/standard_params.go index c45859d55..9a9d00248 100644 --- a/http/jsonapi/runtime/standard_params.go +++ b/http/jsonapi/runtime/standard_params.go @@ -9,8 +9,7 @@ import ( "strings" "github.com/caarlos0/env/v10" - "github.com/go-pg/pg" - "github.com/go-pg/pg/orm" + "github.com/uptrace/bun" "github.com/pace/bricks/maintenance/log" ) @@ -68,7 +67,7 @@ type UrlQueryParameters struct { PageNr int PageSize int Order []string - Filter map[string][]interface{} + Filter map[string][]any } // ReadURLQueryParameters reads sorting, filter and pagination from requests and return a UrlQueryParameters object, @@ -98,11 +97,12 @@ func ReadURLQueryParameters(r *http.Request, mapper ColumnMapper, sanitizer Valu return result, fmt.Errorf("reading URL Query Parameters cased multiple errors: %v", strings.Join(errAggregate, ",")) } -// AddToQuery adds filter, sorting and pagination to a orm.Query -func (u *UrlQueryParameters) AddToQuery(query *orm.Query) *orm.Query { +// AddToQuery adds filter, sorting and pagination to a query. +func (u *UrlQueryParameters) AddToQuery(query *bun.SelectQuery) *bun.SelectQuery { if u.HasPagination { query.Offset(u.PageSize * u.PageNr).Limit(u.PageSize) } + for name, filterValues := range u.Filter { if len(filterValues) == 0 { continue @@ -112,11 +112,14 @@ func (u *UrlQueryParameters) AddToQuery(query *orm.Query) *orm.Query { query.Where(name+" = ?", filterValues[0]) continue } - query.Where(name+" IN (?)", pg.In(filterValues)) + + query.Where(name+" IN (?)", bun.In(filterValues)) } + for _, val := range u.Order { query.Order(val) } + return query } diff --git a/http/jsonapi/runtime/standard_params_test.go b/http/jsonapi/runtime/standard_params_test.go index 72be89027..23ccceb54 100644 --- a/http/jsonapi/runtime/standard_params_test.go +++ b/http/jsonapi/runtime/standard_params_test.go @@ -8,13 +8,12 @@ import ( "sort" "testing" - "github.com/go-pg/pg" - "github.com/go-pg/pg/orm" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" "github.com/pace/bricks/backend/postgres" "github.com/pace/bricks/http/jsonapi/runtime" - "github.com/pace/bricks/maintenance/log" ) type TestModel struct { @@ -31,125 +30,141 @@ func TestIntegrationFilterParameter(t *testing.T) { if testing.Short() { t.SkipNow() } + + ctx := context.Background() + // Setup - a := assert.New(t) - db := setupDatabase(a) + db := setupDatabase(ctx, t) + defer func() { // Tear Down - err := db.DropTable(&TestModel{}, &orm.DropTableOptions{}) - assert.NoError(t, err) + _, err := db.NewDropTable().Model((*TestModel)(nil)).IfExists().Exec(context.Background()) + require.NoError(t, err) }() mappingNames := map[string]string{ "test": "filter_name", } mapper := runtime.NewMapMapper(mappingNames) + // filter r := httptest.NewRequest("GET", "http://abc.de/whatEver?filter[test]=b", nil) + urlParams, err := runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) - a.NoError(err) + require.NoError(t, err) + var modelsFilter []TestModel - q := db.Model(&modelsFilter) + + q := db.NewSelect().Model(&modelsFilter) q = urlParams.AddToQuery(q) - count, _ := q.SelectAndCount() - a.Equal(1, count) - a.Equal("b", modelsFilter[0].FilterName) + + count, err := q.ScanAndCount(ctx) + require.NoError(t, err) + + assert.Equal(t, 1, count) + assert.Equal(t, "b", modelsFilter[0].FilterName) r = httptest.NewRequest("GET", "http://abc.de/whatEver?filter[test]=a,b", nil) + urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) - a.NoError(err) + require.NoError(t, err) + var modelsFilter2 []TestModel - q = db.Model(&modelsFilter2) + + q = db.NewSelect().Model(&modelsFilter2) q = urlParams.AddToQuery(q) - count, _ = q.SelectAndCount() - a.Equal(2, count) + + count, err = q.ScanAndCount(ctx) + require.NoError(t, err) + + assert.Equal(t, 2, count) + sort.Slice(modelsFilter2, func(i, j int) bool { return modelsFilter2[i].FilterName < modelsFilter2[j].FilterName }) - a.Equal("a", modelsFilter2[0].FilterName) - a.Equal("b", modelsFilter2[1].FilterName) + + assert.Equal(t, "a", modelsFilter2[0].FilterName) + assert.Equal(t, "b", modelsFilter2[1].FilterName) // Paging r = httptest.NewRequest("GET", "http://abc.de/whatEver?page[number]=1&page[size]=2", nil) + urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) - assert.NoError(t, err) + require.NoError(t, err) + var modelsPaging []TestModel - q = db.Model(&modelsPaging) + + q = db.NewSelect().Model(&modelsPaging) q = urlParams.AddToQuery(q) - err = q.Select() - a.NoError(err) + + err = q.Scan(ctx) + require.NoError(t, err) + sort.Slice(modelsPaging, func(i, j int) bool { return modelsPaging[i].FilterName < modelsPaging[j].FilterName }) - a.Equal("c", modelsPaging[0].FilterName) - a.Equal("d", modelsPaging[1].FilterName) + + assert.Equal(t, "c", modelsPaging[0].FilterName) + assert.Equal(t, "d", modelsPaging[1].FilterName) // Sorting r = httptest.NewRequest("GET", "http://abc.de/whatEver?sort=-test", nil) + urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) - assert.NoError(t, err) + require.NoError(t, err) + var modelsSort []TestModel - q = db.Model(&modelsSort) + + q = db.NewSelect().Model(&modelsSort) q = urlParams.AddToQuery(q) - err = q.Select() - a.NoError(err) - a.Equal(6, len(modelsSort)) - a.Equal("f", modelsSort[0].FilterName) - a.Equal("e", modelsSort[1].FilterName) - a.Equal("d", modelsSort[2].FilterName) - a.Equal("c", modelsSort[3].FilterName) - a.Equal("b", modelsSort[4].FilterName) - a.Equal("a", modelsSort[5].FilterName) + + err = q.Scan(ctx) + require.NoError(t, err) + + assert.Equal(t, 6, len(modelsSort)) + assert.Equal(t, "f", modelsSort[0].FilterName) + assert.Equal(t, "e", modelsSort[1].FilterName) + assert.Equal(t, "d", modelsSort[2].FilterName) + assert.Equal(t, "c", modelsSort[3].FilterName) + assert.Equal(t, "b", modelsSort[4].FilterName) + assert.Equal(t, "a", modelsSort[5].FilterName) // Combine all r = httptest.NewRequest("GET", "http://abc.de/whatEver?sort=-test&filter[test]=a,b,e,f&page[number]=1&page[size]=2", nil) + urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) - assert.NoError(t, err) + require.NoError(t, err) + var modelsCombined []TestModel - q = db.Model(&modelsCombined) + + q = db.NewSelect().Model(&modelsCombined) q = urlParams.AddToQuery(q) - err = q.Select() - assert.NoError(t, err) - a.Equal(2, len(modelsCombined)) - a.Equal("b", modelsCombined[0].FilterName) - a.Equal("a", modelsCombined[1].FilterName) + + err = q.Scan(ctx) + require.NoError(t, err) + + assert.Equal(t, 2, len(modelsCombined)) + assert.Equal(t, "b", modelsCombined[0].FilterName) + assert.Equal(t, "a", modelsCombined[1].FilterName) } -func setupDatabase(a *assert.Assertions) *pg.DB { - db := postgres.DefaultConnectionPool() - db = db.WithContext(log.WithContext(context.Background())) - - err := db.CreateTable(&TestModel{}, &orm.CreateTableOptions{}) - a.NoError(err) - _, err = db.Model(&TestModel{ - FilterName: "a", - }).Insert() - a.NoError(err) - - _, err = db.Model(&TestModel{ - FilterName: "b", - }).Insert() - a.NoError(err) - - _, err = db.Model(&TestModel{ - FilterName: "c", - }).Insert() - a.NoError(err) - - _, err = db.Model(&TestModel{ - FilterName: "d", - }).Insert() - a.NoError(err) - - _, err = db.Model(&TestModel{ - FilterName: "e", - }).Insert() - a.NoError(err) - - _, err = db.Model(&TestModel{ - FilterName: "f", - }).Insert() - a.NoError(err) +func setupDatabase(ctx context.Context, t *testing.T) *bun.DB { + db := postgres.NewDB(context.Background()) + + _, err := db.NewCreateTable().Model((*TestModel)(nil)).Exec(ctx) + require.NoError(t, err) + + testModels := []TestModel{ + {FilterName: "a"}, + {FilterName: "b"}, + {FilterName: "c"}, + {FilterName: "d"}, + {FilterName: "e"}, + {FilterName: "f"}, + } + + _, err = db.NewInsert().Model(&testModels).Exec(ctx) + require.NoError(t, err) return db } diff --git a/tools/testserver/main.go b/tools/testserver/main.go index c88b09924..51bd2fdec 100755 --- a/tools/testserver/main.go +++ b/tools/testserver/main.go @@ -66,7 +66,7 @@ func (*TestService) GetTest(ctx context.Context, w simple.GetTestResponseWriter, } func main() { - db := postgres.DefaultConnectionPool() + db := postgres.NewDB(context.Background()) rdb := redis.Client() cdb, err := couchdb.DefaultDatabase() if err != nil { @@ -100,16 +100,22 @@ func main() { defer handlerSpan.Finish() // do dummy database query - cdb := db.WithContext(ctx) var result struct { Calc int //nolint } - res, err := cdb.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) + res, err := db.NewSelect().Model(&result).ColumnExpr("? + ? AS Calc", 10, 10).Exec(ctx) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("Calc failed") return } - log.Ctx(ctx).Debug().Int("rows_affected", res.RowsAffected()).Msg("Calc done") + + count, err := res.RowsAffected() + if err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("RowsAffected failed") + return + } + + log.Ctx(ctx).Debug().Int64("rows_affected", count).Msg("Calc done") // do dummy redis query crdb := redis.WithContext(ctx, rdb)