diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go index 320209dabea8f68921b13fc0e0e629ed3d9230c6..b6a1058ffb6b9263b7fa6bc69352459916a98233 100644 --- a/internal/sqladapter/collection.go +++ b/internal/sqladapter/collection.go @@ -76,12 +76,12 @@ func (c *collection) InsertReturning(item interface{}) error { inTx := false if currTx := c.p.Database().Transaction(); currTx != nil { - tx = newTxWrapper(c.p.Database()) + tx = NewTx(c.p.Database()) inTx = true } else { // Not within a transaction, let's create one. var err error - tx, err = c.p.Database().NewLocalTransaction() + tx, err = c.p.Database().NewLocalTransaction(c.p.Database().Context()) if err != nil { return err } diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 4c6cd6d88806df48337f62f7f99ff7e09aea03dc..5f233092706e0fb665dcb3000aa9328e94745fab 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -54,7 +54,7 @@ type PartialDatabase interface { ConnectionURL() db.ConnectionURL Err(in error) (out error) - NewLocalTransaction() (DatabaseTx, error) + NewLocalTransaction(ctx context.Context) (DatabaseTx, error) } // BaseDatabase defines the methods provided by sqladapter that do not have to @@ -74,7 +74,7 @@ type BaseDatabase interface { BindSession(*sql.DB) error Session() *sql.DB - BindTx(*sql.Tx) error + BindTx(context.Context, *sql.Tx) error Transaction() BaseTx SetConnMaxLifetime(time.Duration) @@ -130,7 +130,7 @@ func (d *database) Session() *sql.DB { } // BindTx binds a *sql.Tx into *database -func (d *database) BindTx(t *sql.Tx) error { +func (d *database) BindTx(ctx context.Context, t *sql.Tx) error { d.sessMu.Lock() defer d.sessMu.Unlock() @@ -139,6 +139,7 @@ func (d *database) BindTx(t *sql.Tx) error { return err } + d.ctx = ctx d.txID = newTxID() return nil } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 1ddd7fe651940c39c300fe44c88705cdc191638e..cfc40805c8920a576bdc5a1d30ebe856d9aadff2 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -284,7 +284,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) { err := sess.Collection("artist").Truncate() assert.NoError(t, err) - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) defer tx.Close() @@ -1034,7 +1034,7 @@ func TestTransactionsAndRollback(t *testing.T) { sess := mustOpen() // Simple transaction that should not fail. - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) artist := tx.Collection("artist") @@ -1059,7 +1059,7 @@ func TestTransactionsAndRollback(t *testing.T) { assert.NoError(t, err) // Use another transaction. - tx, err = sess.NewTx() + tx, err = sess.NewTx(nil) assert.NoError(t, err) artist = tx.Collection("artist") @@ -1092,7 +1092,7 @@ func TestTransactionsAndRollback(t *testing.T) { assert.NoError(t, err) // Attempt to add some rows. - tx, err = sess.NewTx() + tx, err = sess.NewTx(nil) assert.NoError(t, err) artist = tx.Collection("artist") @@ -1123,7 +1123,7 @@ func TestTransactionsAndRollback(t *testing.T) { assert.NoError(t, err) // Attempt to add some rows. - tx, err = sess.NewTx() + tx, err = sess.NewTx(nil) assert.NoError(t, err) artist = tx.Collection("artist") @@ -1506,7 +1506,7 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) assert.NotZero(t, all) - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) assert.NotZero(t, tx) defer tx.Close() @@ -1556,7 +1556,7 @@ func TestExhaustConnectionPool(t *testing.T) { // Requesting a new transaction session. start := time.Now() tLogf("Tx: %d: NewTx", i) - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) if err != nil { tFatal(err) } diff --git a/internal/sqladapter/tx.go b/internal/sqladapter/tx.go index 3239f3104f13839753f56085ee62d48d54095566..9ab7a8b6abd23bd4318ea1a4dea3df74488fd9b3 100644 --- a/internal/sqladapter/tx.go +++ b/internal/sqladapter/tx.go @@ -22,6 +22,7 @@ package sqladapter import ( + "context" "database/sql" "sync/atomic" @@ -57,13 +58,6 @@ func NewTx(db Database) DatabaseTx { } } -func newTxWrapper(db Database) DatabaseTx { - return &txWrapper{ - Database: db, - BaseTx: db.Transaction(), - } -} - type sqlTx struct { *sql.Tx committed atomic.Value @@ -99,8 +93,8 @@ func (t *txWrapper) Rollback() error { } // RunTx creates a transaction context and runs fn within it. -func RunTx(d sqlbuilder.Database, fn func(tx sqlbuilder.Tx) error) error { - tx, err := d.NewTx() +func RunTx(d sqlbuilder.Database, ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + tx, err := d.NewTx(ctx) if err != nil { return err } diff --git a/lib/sqlbuilder/wrapper.go b/lib/sqlbuilder/wrapper.go index 2c2445d9b5b1b9bd8c5096920a06b5bd551fa6a0..05078c2d5d0180370cf78e7498a5d9f28ae17e06 100644 --- a/lib/sqlbuilder/wrapper.go +++ b/lib/sqlbuilder/wrapper.go @@ -22,6 +22,7 @@ package sqlbuilder import ( + "context" "database/sql" "fmt" "sync" @@ -50,6 +51,8 @@ type Backend interface { type Tx interface { Backend db.Tx + + Context() context.Context } // Database represents a Database which is capable of both creating @@ -59,14 +62,14 @@ type Database interface { // NewTx returns a new session that lives within a transaction. This session // is completely independent from its parent. - NewTx() (Tx, error) + NewTx(ctx context.Context) (Tx, error) // Tx creates a new transaction that is passed as context to the fn function. // The fn function defines a transaction operation. If the fn function // returns nil, the transaction is commited, otherwise the transaction is // rolled back. The transaction session is closed after the function exists, // regardless of the error value returned by fn. - Tx(fn func(sess Tx) error) error + Tx(ctx context.Context, fn func(sess Tx) error) error } // AdapterFuncMap is a struct that defines a set of functions that adapters diff --git a/mysql/database.go b/mysql/database.go index 068865b3f9ab25cb9c645ebe0351ce93931efefc..905f8fec8d262258c21b07d2272324c30d1bbeb6 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -22,6 +22,7 @@ package mysql import ( + "context" "strings" "sync" @@ -70,8 +71,11 @@ func (d *database) Open(connURL db.ConnectionURL) error { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + if ctx == nil { + ctx = d.Context() + } + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -112,9 +116,9 @@ func (d *database) open() error { connFn := func() error { sess, err := sql.Open("mysql", d.ConnectionURL().String()) if err == nil { - sess.SetConnMaxLifetime(connMaxLifetime) - sess.SetMaxIdleConns(maxIdleConns) - sess.SetMaxOpenConns(maxOpenConns) + sess.SetConnMaxLifetime(db.DefaultConnMaxLifetime) + sess.SetMaxIdleConns(db.DefaultMaxIdleConns) + sess.SetMaxOpenConns(db.DefaultMaxOpenConns) return d.BaseDatabase.BindSession(sess) } return err @@ -172,12 +176,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -187,9 +191,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() connFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil) if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/mysql/mysql.go b/mysql/mysql.go index 7f8da66efd0ae06889fc88cc555c0ec42313e6e6..eb4e752fccf73ef23f98c09d338e3ac47dc79762 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -23,7 +23,6 @@ package mysql // import "upper.io/db.v3/mysql" import ( "database/sql" - "time" "upper.io/db.v3" @@ -79,7 +78,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } @@ -109,30 +108,3 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) { } return d, nil } - -// SetConnMaxLifetime sets the default value to be passed to -// db.SetConnMaxLifetime. -func SetConnMaxLifetime(d time.Duration) { - connMaxLifetime = d -} - -// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns. -func SetMaxIdleConns(n int) { - if n < 0 { - n = 0 - } - maxIdleConns = n -} - -// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns. -// If the value of maxIdleConns is >= 0 and maxOpenConns is less than -// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns. -func SetMaxOpenConns(n int) { - if n < 0 { - n = 0 - } - if n > maxIdleConns { - maxIdleConns = n - } - maxOpenConns = n -} diff --git a/postgresql/database.go b/postgresql/database.go index a95975e62254ec84bf1854c94e202e29127d77b9..493947b393972764377d4b7853e74ed820e87156 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -22,6 +22,7 @@ package postgresql import ( + "context" "database/sql" "strings" "sync" @@ -69,8 +70,11 @@ func (d *database) Open(connURL db.ConnectionURL) error { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + if ctx == nil { + ctx = context.Background() + } + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -111,9 +115,9 @@ func (d *database) open() error { connFn := func() error { sess, err := sql.Open("postgres", d.ConnectionURL().String()) if err == nil { - sess.SetConnMaxLifetime(connMaxLifetime) - sess.SetMaxIdleConns(maxIdleConns) - sess.SetMaxOpenConns(maxOpenConns) + sess.SetConnMaxLifetime(db.DefaultConnMaxLifetime) + sess.SetMaxIdleConns(db.DefaultMaxIdleConns) + sess.SetMaxOpenConns(db.DefaultMaxOpenConns) return d.BaseDatabase.BindSession(sess) } return err @@ -172,12 +176,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -187,9 +191,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() connFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil) if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/postgresql/local_test.go b/postgresql/local_test.go index 1b7b5950c42d9d4177da62ae5b7eb5cfdca22b90..22f010cc02b6fc3dd3ea1310dc44178f45d77369 100644 --- a/postgresql/local_test.go +++ b/postgresql/local_test.go @@ -100,7 +100,7 @@ func TestIssue210(t *testing.T) { sess := mustOpen() defer sess.Close() - tx, err := sess.NewTx() + tx, err := sess.NewTx(nil) assert.NoError(t, err) for i := range list { diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index bbbc5ffff83d522603ea0482873ab79aba3bf56f..e48e051cf24f78b05fc20978b9fab4aca36c8621 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -23,7 +23,6 @@ package postgresql // import "upper.io/db.v3/postgresql" import ( "database/sql" - "time" "upper.io/db.v3" @@ -31,12 +30,6 @@ import ( "upper.io/db.v3/lib/sqlbuilder" ) -var ( - connMaxLifetime = db.DefaultConnMaxLifetime - maxIdleConns = db.DefaultMaxIdleConns - maxOpenConns = db.DefaultMaxOpenConns -) - const sqlDriver = `postgres` // Adapter is the public name of the adapter. @@ -79,7 +72,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } @@ -109,30 +102,3 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) { } return d, nil } - -// SetConnMaxLifetime sets the default value to be passed to -// db.SetConnMaxLifetime. -func SetConnMaxLifetime(d time.Duration) { - connMaxLifetime = d -} - -// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns. -func SetMaxIdleConns(n int) { - if n < 0 { - n = 0 - } - maxIdleConns = n -} - -// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns. -// If the value of maxIdleConns is >= 0 and maxOpenConns is less than -// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns. -func SetMaxOpenConns(n int) { - if n < 0 { - n = 0 - } - if n > maxIdleConns { - maxIdleConns = n - } - maxOpenConns = n -} diff --git a/ql/database.go b/ql/database.go index a71b951a804838073755b8611588d1fb2ece8021..58566a4822085f9e281671219b0b7032a0b635b8 100644 --- a/ql/database.go +++ b/ql/database.go @@ -113,7 +113,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err } @@ -145,8 +145,11 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + if ctx == nil { + ctx = d.Context() + } + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -272,12 +275,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -287,9 +290,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { defer clone.txMu.Unlock() openFn := func() error { - sqlTx, err := clone.BaseDatabase.Session().Begin() + sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil) if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/sqlite/database.go b/sqlite/database.go index f1efa16707bb6a78d6d5fe9c0f9c7ffa56079007..b04113cb244b9b2b90a1b0efd26d1de17627d6d6 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -22,6 +22,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -85,8 +86,8 @@ func (d *database) Open(connURL db.ConnectionURL) error { } // NewTx starts a transaction block. -func (d *database) NewTx() (sqlbuilder.Tx, error) { - nTx, err := d.NewLocalTransaction() +func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + nTx, err := d.NewLocalTransaction(ctx) if err != nil { return nil, err } @@ -215,12 +216,12 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Tx creates a transaction and passes it to the given function, if if the // function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error { - return sqladapter.RunTx(d, fn) +func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error { + return sqladapter.RunTx(d, ctx, fn) } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { +func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err @@ -232,7 +233,7 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { openFn := func() error { sqlTx, err := clone.BaseDatabase.Session().Begin() if err == nil { - return clone.BindTx(sqlTx) + return clone.BindTx(ctx, sqlTx) } return err } diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index 894459a2e5d91da18caeeb23a60edf322426861f..5c22c275a5212d01e637e691ab0c60ebc06d64f9 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -72,7 +72,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { } d.Builder = b - if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil { return nil, err }