From 2ec50be835be97fa194b29ad2cf468da6298e54a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sami=20Salih=20=C4=B0brahimba=C5=9F?= Date: Tue, 13 Aug 2024 11:51:44 +0300 Subject: [PATCH] feature/txnsql-adapter: add sql interface --- txnsql/sql.go | 18 ++++++++++++++++-- txnsql/sql_test.go | 10 +++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/txnsql/sql.go b/txnsql/sql.go index d4e62c0..8b412e1 100644 --- a/txnsql/sql.go +++ b/txnsql/sql.go @@ -7,13 +7,24 @@ import ( "github.com/9ssi7/txn" ) +type SqlDbTx interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Prepare(query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + Exec(query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + Query(query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + QueryRow(query string, args ...any) *sql.Row +} + // SqlAdapter is the interface for interacting with SQL databases within a transaction. // It extends the txn.Adapter interface to provide additional SQL-specific functionality. type SqlAdapter interface { txn.Adapter // Returns current transaction if it exists. - Tx() *sql.Tx + GetCurrent() SqlDbTx } // New creates a new SqlAdapter instance using the provided *sql.DB. @@ -59,6 +70,9 @@ func (a *sqlAdapter) End(_ context.Context) { } } -func (a *sqlAdapter) Tx() *sql.Tx { +func (a *sqlAdapter) GetCurrent() SqlDbTx { + if a.tx == nil { + return a.db + } return a.tx } diff --git a/txnsql/sql_test.go b/txnsql/sql_test.go index ebbfafc..c997e1b 100644 --- a/txnsql/sql_test.go +++ b/txnsql/sql_test.go @@ -141,23 +141,23 @@ func TestSqlAdapter_End(t *testing.T) { }) } -func TestSqlAdapter_Tx(t *testing.T) { +func TestSqlAdapter_GetCurrent(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) } defer db.Close() - t.Run("Tx is not nil", func(t *testing.T) { + t.Run("Tx is not nil, current is *sql.Tx", func(t *testing.T) { adapter := &sqlAdapter{db: db} mock.ExpectBegin() adapter.Begin(context.Background()) - assert.True(t, adapter.Tx() != nil) + assert.True(t, adapter.GetCurrent() == adapter.tx) }) - t.Run("Tx is nil", func(t *testing.T) { + t.Run("Tx is nil, current is *sql.DB", func(t *testing.T) { adapter := &sqlAdapter{db: db} - assert.True(t, adapter.Tx() == nil) + assert.True(t, adapter.GetCurrent() == adapter.db) }) }