Skip to content

Commit

Permalink
feature/txnsql-adapter: add sql interface
Browse files Browse the repository at this point in the history
  • Loading branch information
9ssi7 committed Aug 13, 2024
1 parent fb24d2d commit 2ec50be
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
18 changes: 16 additions & 2 deletions txnsql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,24 @@ import (
"github.com/9ssi7/txn"
)

type SqlDbTx interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Prepare(query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Exec(query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
Query(query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
QueryRow(query string, args ...any) *sql.Row
}

// SqlAdapter is the interface for interacting with SQL databases within a transaction.
// It extends the txn.Adapter interface to provide additional SQL-specific functionality.
type SqlAdapter interface {
txn.Adapter

// Returns current transaction if it exists.
Tx() *sql.Tx
GetCurrent() SqlDbTx
}

// New creates a new SqlAdapter instance using the provided *sql.DB.
Expand Down Expand Up @@ -59,6 +70,9 @@ func (a *sqlAdapter) End(_ context.Context) {
}
}

func (a *sqlAdapter) Tx() *sql.Tx {
func (a *sqlAdapter) GetCurrent() SqlDbTx {
if a.tx == nil {
return a.db
}
return a.tx
}
10 changes: 5 additions & 5 deletions txnsql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,23 @@ func TestSqlAdapter_End(t *testing.T) {
})
}

func TestSqlAdapter_Tx(t *testing.T) {
func TestSqlAdapter_GetCurrent(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()

t.Run("Tx is not nil", func(t *testing.T) {
t.Run("Tx is not nil, current is *sql.Tx", func(t *testing.T) {
adapter := &sqlAdapter{db: db}
mock.ExpectBegin()

adapter.Begin(context.Background())
assert.True(t, adapter.Tx() != nil)
assert.True(t, adapter.GetCurrent() == adapter.tx)
})

t.Run("Tx is nil", func(t *testing.T) {
t.Run("Tx is nil, current is *sql.DB", func(t *testing.T) {
adapter := &sqlAdapter{db: db}
assert.True(t, adapter.Tx() == nil)
assert.True(t, adapter.GetCurrent() == adapter.db)
})
}

0 comments on commit 2ec50be

Please sign in to comment.