From 00d0a05dc37c403e372008cc7b381536e5914219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Sat, 17 Dec 2016 01:24:03 +0000 Subject: [PATCH] Add context to transactions --- internal/sqladapter/collection.go | 4 +-- internal/sqladapter/database.go | 7 +++-- internal/sqladapter/testing/adapter.go.tpl | 14 ++++----- internal/sqladapter/tx.go | 12 ++------ lib/sqlbuilder/wrapper.go | 7 +++-- mysql/database.go | 24 +++++++++------ mysql/mysql.go | 30 +----------------- postgresql/database.go | 24 +++++++++------ postgresql/local_test.go | 2 +- postgresql/postgresql.go | 36 +--------------------- ql/database.go | 19 +++++++----- sqlite/database.go | 13 ++++---- sqlite/sqlite.go | 2 +- 13 files changed, 71 insertions(+), 123 deletions(-) diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go index 320209da..b6a1058f 100644 --- a/internal/sqladapter/collection.go +++ b/internal/sqladapter/collection.go @@ -76,12 +76,12 @@ func (c *collection) InsertReturning(item interface{}) error { inTx := false if currTx := c.p.Database().Transaction(); currTx != nil { - tx = newTxWrapper(c.p.Database()) + tx = NewTx(c.p.Database()) inTx = true } else { // Not within a transaction, let's create one. var err error - tx, err = c.p.Database().NewLocalTransaction() + tx, err = c.p.Database().NewLocalTransaction(c.p.Database().Context()) if err != nil { return err } diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 4c6cd6d8..5f233092 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -54,7 +54,7 @@ type PartialDatabase interface { ConnectionURL() db.ConnectionURL Err(in error) (out error) - NewLocalTransaction() (DatabaseTx, error) + NewLocalTransaction(ctx context.Context) (DatabaseTx, error) } // BaseDatabase defines the methods provided by sqladapter that do not have to @@ -74,7 +74,7 @@ type BaseDatabase interface { BindSession(*sql.DB) error Session() *sql.DB - BindTx(*sql.Tx) error + BindTx(context.Context, *sql.Tx) error Transaction() BaseTx SetConnMaxLifetime(time.Duration) @@ -130,7 +130,7 @@ func (d *database) Session() *sql.DB { } // BindTx binds a *sql.Tx into *database -func (d *database) BindTx(t *sql.Tx) error { +func (d *database) BindTx(ctx context.Context, t *sql.Tx) error { d.sessMu.Lock() defer d.sessMu.Unlock() @@ -139,6 +139,7 @@ func (d *database) BindTx(t *sql.Tx) error { return err } + d.ctx = ctx d.txID = newTxID() return nil } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 1ddd7fe6..cfc40805 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -284,7 +284,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) { err := sess.Collection("artist").Truncate() assert.NoError(t, err) - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) defer tx.Close() @@ -1034,7 +1034,7 @@ func TestTransactionsAndRollback(t *testing.T) { sess := mustOpen() // Simple transaction that should not fail. - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) artist := tx.Collection("artist") @@ -1059,7 +1059,7 @@ func TestTransactionsAndRollback(t *testing.T) { assert.NoError(t, err) // Use another transaction. - tx, err = sess.NewTx() + tx, err = sess.NewTx(nil) assert.NoError(t, err) artist = tx.Collection("artist") @@ -1092,7 +1092,7 @@ func TestTransactionsAndRollback(t *testing.T) { assert.NoError(t, err) // Attempt to add some rows. - tx, err = sess.NewTx() + tx, err = sess.NewTx(nil) assert.NoError(t, err) artist = tx.Collection("artist") @@ -1123,7 +1123,7 @@ func TestTransactionsAndRollback(t *testing.T) { assert.NoError(t, err) // Attempt to add some rows. - tx, err = sess.NewTx() + tx, err = sess.NewTx(nil) assert.NoError(t, err) artist = tx.Collection("artist") @@ -1506,7 +1506,7 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) assert.NotZero(t, all) - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) assert.NotZero(t, tx) defer tx.Close() @@ -1556,7 +1556,7 @@ func TestExhaustConnectionPool(t *testing.T) { // Requesting a new transaction session. start := time.Now() tLogf("Tx: %d: NewTx", i) - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) if err != nil { tFatal(err) } diff --git a/internal/sqladapter/tx.go b/internal/sqladapter/tx.go index 3239f310..9ab7a8b6 100644 --- a/internal/sqladapter/tx.go +++ b/internal/sqladapter/tx.go @@ -22,6 +22,7 @@ package sqladapter import ( + "context" "database/sql" "sync/atomic" @@ -57,13 +58,6 @@ func NewTx(db Database) DatabaseTx { } } -func newTxWrapper(db Database) DatabaseTx { - return &txWrapper{ - Database: db, - BaseTx: db.Transaction(), - } -} - type sqlTx struct { *sql.Tx committed atomic.Value @@ -99,8 +93,8 @@ func (t *txWrapper) Rollback() error { } // RunTx creates a transaction context and runs fn within it. -func RunTx(d sqlbuilder.Database, fn func(tx sqlbuilder.Tx) error) error { - tx, err := d.NewTx() +func RunTx(d sqlbuilder.Database, ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + tx, err := d.NewTx(ctx) if err != nil { return err } diff --git a/lib/sqlbuilder/wrapper.go b/lib/sqlbuilder/wrapper.go index 2c2445d9..05078c2d 100644 --- a/lib/sqlbuilder/wrapper.go +++ b/lib/sqlbuilder/wrapper.go @@ -22,6 +22,7 @@ package sqlbuilder import ( + "context" "database/sql" "fmt" "sync" @@ -50,6 +51,8 @@ type Backend interface { type Tx interface { Backend db.Tx + + Context() context.Context } // Database represents a Database which is capable of both creating @@ -59,14 +62,14 @@ type Database interface { // NewTx returns a new session that lives within a transaction. This session // is completely independent from its parent. - NewTx() (Tx, error) + NewTx(ctx context.Context) (Tx, error) // Tx creates a new transaction that is passed as context to the fn function. // The fn function defines a transaction operation. If the fn function // returns nil, the transaction is commited, otherwise the transaction is // rolled back. The transaction session is closed after the function exists, // regardless of the error value returned by fn. - Tx(fn func(sess Tx) error) error + Tx(ctx context.Context, fn func(sess Tx) error) error } // AdapterFuncMap is a struct that defines a set of functions that adapters diff --git a/mysql/database.go b/mysql/database.go index 068865b3..905f8fec 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -22,6 +22,7 @@ package mysql import ( + "context" "strings" "sync" @@ -70,8 +71,11 @@ func (d *database) Open(connURL db.ConnectionURL) error { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + if ctx == nil { + ctx = d.Context() + } + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -112,9 +116,9 @@ func (d *database) open() error { connFn := func() error { sess, err := sql.Open("mysql", d.ConnectionURL().String()) if err == nil { - sess.SetConnMaxLifetime(connMaxLifetime) - sess.SetMaxIdleConns(maxIdleConns) - sess.SetMaxOpenConns(maxOpenConns) + sess.SetConnMaxLifetime(db.DefaultConnMaxLifetime) + sess.SetMaxIdleConns(db.DefaultMaxIdleConns) + sess.SetMaxOpenConns(db.DefaultMaxOpenConns) return d.BaseDatabase.BindSession(sess) } return err @@ -172,12 +176,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -187,9 +191,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() connFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil) if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/mysql/mysql.go b/mysql/mysql.go index 7f8da66e..eb4e752f 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -23,7 +23,6 @@ package mysql // import "upper.io/db.v3/mysql" import ( "database/sql" - "time" "upper.io/db.v3" @@ -79,7 +78,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } @@ -109,30 +108,3 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) { } return d, nil } - -// SetConnMaxLifetime sets the default value to be passed to -// db.SetConnMaxLifetime. -func SetConnMaxLifetime(d time.Duration) { - connMaxLifetime = d -} - -// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns. -func SetMaxIdleConns(n int) { - if n < 0 { - n = 0 - } - maxIdleConns = n -} - -// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns. -// If the value of maxIdleConns is >= 0 and maxOpenConns is less than -// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns. -func SetMaxOpenConns(n int) { - if n < 0 { - n = 0 - } - if n > maxIdleConns { - maxIdleConns = n - } - maxOpenConns = n -} diff --git a/postgresql/database.go b/postgresql/database.go index a95975e6..493947b3 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -22,6 +22,7 @@ package postgresql import ( + "context" "database/sql" "strings" "sync" @@ -69,8 +70,11 @@ func (d *database) Open(connURL db.ConnectionURL) error { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + if ctx == nil { + ctx = context.Background() + } + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -111,9 +115,9 @@ func (d *database) open() error { connFn := func() error { sess, err := sql.Open("postgres", d.ConnectionURL().String()) if err == nil { - sess.SetConnMaxLifetime(connMaxLifetime) - sess.SetMaxIdleConns(maxIdleConns) - sess.SetMaxOpenConns(maxOpenConns) + sess.SetConnMaxLifetime(db.DefaultConnMaxLifetime) + sess.SetMaxIdleConns(db.DefaultMaxIdleConns) + sess.SetMaxOpenConns(db.DefaultMaxOpenConns) return d.BaseDatabase.BindSession(sess) } return err @@ -172,12 +176,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -187,9 +191,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() connFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil) if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/postgresql/local_test.go b/postgresql/local_test.go index 1b7b5950..22f010cc 100644 --- a/postgresql/local_test.go +++ b/postgresql/local_test.go @@ -100,7 +100,7 @@ func TestIssue210(t *testing.T) { sess := mustOpen() defer sess.Close() - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) for i := range list { diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index bbbc5fff..e48e051c 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -23,7 +23,6 @@ package postgresql // import "upper.io/db.v3/postgresql" import ( "database/sql" - "time" "upper.io/db.v3" @@ -31,12 +30,6 @@ import ( "upper.io/db.v3/lib/sqlbuilder" ) -var ( - connMaxLifetime = db.DefaultConnMaxLifetime - maxIdleConns = db.DefaultMaxIdleConns - maxOpenConns = db.DefaultMaxOpenConns -) - const sqlDriver = `postgres` // Adapter is the public name of the adapter. @@ -79,7 +72,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } @@ -109,30 +102,3 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) { } return d, nil } - -// SetConnMaxLifetime sets the default value to be passed to -// db.SetConnMaxLifetime. -func SetConnMaxLifetime(d time.Duration) { - connMaxLifetime = d -} - -// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns. -func SetMaxIdleConns(n int) { - if n < 0 { - n = 0 - } - maxIdleConns = n -} - -// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns. -// If the value of maxIdleConns is >= 0 and maxOpenConns is less than -// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns. -func SetMaxOpenConns(n int) { - if n < 0 { - n = 0 - } - if n > maxIdleConns { - maxIdleConns = n - } - maxOpenConns = n -} diff --git a/ql/database.go b/ql/database.go index a71b951a..58566a48 100644 --- a/ql/database.go +++ b/ql/database.go @@ -113,7 +113,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } @@ -145,8 +145,11 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + if ctx == nil { + ctx = d.Context() + } + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -272,12 +275,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -287,9 +290,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() openFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil) if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/sqlite/database.go b/sqlite/database.go index f1efa167..b04113cb 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -22,6 +22,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -85,8 +86,8 @@ func (d *database) Open(connURL db.ConnectionURL) error { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -215,12 +216,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -232,7 +233,7 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { openFn := func() error { sqlTx, err := clone.BaseDatabase.Session().Begin() if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index 894459a2..5c22c275 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -72,7 +72,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } -- GitLab