From 00d0a05dc37c403e372008cc7b381536e5914219 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Sat, 17 Dec 2016 01:24:03 +0000
Subject: [PATCH] Add context to transactions

---
 internal/sqladapter/collection.go          |  4 +--
 internal/sqladapter/database.go            |  7 +++--
 internal/sqladapter/testing/adapter.go.tpl | 14 ++++-----
 internal/sqladapter/tx.go                  | 12 ++------
 lib/sqlbuilder/wrapper.go                  |  7 +++--
 mysql/database.go                          | 24 +++++++++------
 mysql/mysql.go                             | 30 +-----------------
 postgresql/database.go                     | 24 +++++++++------
 postgresql/local_test.go                   |  2 +-
 postgresql/postgresql.go                   | 36 +---------------------
 ql/database.go                             | 19 +++++++-----
 sqlite/database.go                         | 13 ++++----
 sqlite/sqlite.go                           |  2 +-
 13 files changed, 71 insertions(+), 123 deletions(-)

diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go
index 320209da..b6a1058f 100644
--- a/internal/sqladapter/collection.go
+++ b/internal/sqladapter/collection.go
@@ -76,12 +76,12 @@ func (c *collection) InsertReturning(item interface{}) error {
 	inTx := false
 
 	if currTx := c.p.Database().Transaction(); currTx != nil {
-		tx = newTxWrapper(c.p.Database())
+		tx = NewTx(c.p.Database())
 		inTx = true
 	} else {
 		// Not within a transaction, let's create one.
 		var err error
-		tx, err = c.p.Database().NewLocalTransaction()
+		tx, err = c.p.Database().NewLocalTransaction(c.p.Database().Context())
 		if err != nil {
 			return err
 		}
diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go
index 4c6cd6d8..5f233092 100644
--- a/internal/sqladapter/database.go
+++ b/internal/sqladapter/database.go
@@ -54,7 +54,7 @@ type PartialDatabase interface {
 	ConnectionURL() db.ConnectionURL
 
 	Err(in error) (out error)
-	NewLocalTransaction() (DatabaseTx, error)
+	NewLocalTransaction(ctx context.Context) (DatabaseTx, error)
 }
 
 // BaseDatabase defines the methods provided by sqladapter that do not have to
@@ -74,7 +74,7 @@ type BaseDatabase interface {
 	BindSession(*sql.DB) error
 	Session() *sql.DB
 
-	BindTx(*sql.Tx) error
+	BindTx(context.Context, *sql.Tx) error
 	Transaction() BaseTx
 
 	SetConnMaxLifetime(time.Duration)
@@ -130,7 +130,7 @@ func (d *database) Session() *sql.DB {
 }
 
 // BindTx binds a *sql.Tx into *database
-func (d *database) BindTx(t *sql.Tx) error {
+func (d *database) BindTx(ctx context.Context, t *sql.Tx) error {
 	d.sessMu.Lock()
 	defer d.sessMu.Unlock()
 
@@ -139,6 +139,7 @@ func (d *database) BindTx(t *sql.Tx) error {
 		return err
 	}
 
+	d.ctx = ctx
 	d.txID = newTxID()
 	return nil
 }
diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index 1ddd7fe6..cfc40805 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -284,7 +284,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) {
 	err := sess.Collection("artist").Truncate()
 	assert.NoError(t, err)
 
-	tx, err := sess.NewTx()
+	tx, err := sess.NewTx(nil)
 	assert.NoError(t, err)
 	defer tx.Close()
 
@@ -1034,7 +1034,7 @@ func TestTransactionsAndRollback(t *testing.T) {
 	sess := mustOpen()
 
 	// Simple transaction that should not fail.
-	tx, err := sess.NewTx()
+	tx, err := sess.NewTx(nil)
 	assert.NoError(t, err)
 
 	artist := tx.Collection("artist")
@@ -1059,7 +1059,7 @@ func TestTransactionsAndRollback(t *testing.T) {
 	assert.NoError(t, err)
 
 	// Use another transaction.
-	tx, err = sess.NewTx()
+	tx, err = sess.NewTx(nil)
 	assert.NoError(t, err)
 
 	artist = tx.Collection("artist")
@@ -1092,7 +1092,7 @@ func TestTransactionsAndRollback(t *testing.T) {
 	assert.NoError(t, err)
 
 	// Attempt to add some rows.
-	tx, err = sess.NewTx()
+	tx, err = sess.NewTx(nil)
 	assert.NoError(t, err)
 
 	artist = tx.Collection("artist")
@@ -1123,7 +1123,7 @@ func TestTransactionsAndRollback(t *testing.T) {
 	assert.NoError(t, err)
 
 	// Attempt to add some rows.
-	tx, err = sess.NewTx()
+	tx, err = sess.NewTx(nil)
 	assert.NoError(t, err)
 
 	artist = tx.Collection("artist")
@@ -1506,7 +1506,7 @@ func TestBuilder(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NotZero(t, all)
 
-	tx, err := sess.NewTx()
+	tx, err := sess.NewTx(nil)
 	assert.NoError(t, err)
 	assert.NotZero(t, tx)
 	defer tx.Close()
@@ -1556,7 +1556,7 @@ func TestExhaustConnectionPool(t *testing.T) {
 			// Requesting a new transaction session.
 			start := time.Now()
 			tLogf("Tx: %d: NewTx", i)
-			tx, err := sess.NewTx()
+			tx, err := sess.NewTx(nil)
 			if err != nil {
 				tFatal(err)
 			}
diff --git a/internal/sqladapter/tx.go b/internal/sqladapter/tx.go
index 3239f310..9ab7a8b6 100644
--- a/internal/sqladapter/tx.go
+++ b/internal/sqladapter/tx.go
@@ -22,6 +22,7 @@
 package sqladapter
 
 import (
+	"context"
 	"database/sql"
 	"sync/atomic"
 
@@ -57,13 +58,6 @@ func NewTx(db Database) DatabaseTx {
 	}
 }
 
-func newTxWrapper(db Database) DatabaseTx {
-	return &txWrapper{
-		Database: db,
-		BaseTx:   db.Transaction(),
-	}
-}
-
 type sqlTx struct {
 	*sql.Tx
 	committed atomic.Value
@@ -99,8 +93,8 @@ func (t *txWrapper) Rollback() error {
 }
 
 // RunTx creates a transaction context and runs fn within it.
-func RunTx(d sqlbuilder.Database, fn func(tx sqlbuilder.Tx) error) error {
-	tx, err := d.NewTx()
+func RunTx(d sqlbuilder.Database, ctx context.Context, fn func(tx sqlbuilder.Tx) error) error {
+	tx, err := d.NewTx(ctx)
 	if err != nil {
 		return err
 	}
diff --git a/lib/sqlbuilder/wrapper.go b/lib/sqlbuilder/wrapper.go
index 2c2445d9..05078c2d 100644
--- a/lib/sqlbuilder/wrapper.go
+++ b/lib/sqlbuilder/wrapper.go
@@ -22,6 +22,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 	"fmt"
 	"sync"
@@ -50,6 +51,8 @@ type Backend interface {
 type Tx interface {
 	Backend
 	db.Tx
+
+	Context() context.Context
 }
 
 // Database represents a Database which is capable of both creating
@@ -59,14 +62,14 @@ type Database interface {
 
 	// NewTx returns a new session that lives within a transaction. This session
 	// is completely independent from its parent.
-	NewTx() (Tx, error)
+	NewTx(ctx context.Context) (Tx, error)
 
 	// Tx creates a new transaction that is passed as context to the fn function.
 	// The fn function defines a transaction operation.  If the fn function
 	// returns nil, the transaction is commited, otherwise the transaction is
 	// rolled back.  The transaction session is closed after the function exists,
 	// regardless of the error value returned by fn.
-	Tx(fn func(sess Tx) error) error
+	Tx(ctx context.Context, fn func(sess Tx) error) error
 }
 
 // AdapterFuncMap is a struct that defines a set of functions that adapters
diff --git a/mysql/database.go b/mysql/database.go
index 068865b3..905f8fec 100644
--- a/mysql/database.go
+++ b/mysql/database.go
@@ -22,6 +22,7 @@
 package mysql
 
 import (
+	"context"
 	"strings"
 	"sync"
 
@@ -70,8 +71,11 @@ func (d *database) Open(connURL db.ConnectionURL) error {
 }
 
 // NewTx starts a transaction block.
-func (d *database) NewTx() (sqlbuilder.Tx, error) {
-	nTx, err := d.NewLocalTransaction()
+func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) {
+	if ctx == nil {
+		ctx = d.Context()
+	}
+	nTx, err := d.NewLocalTransaction(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -112,9 +116,9 @@ func (d *database) open() error {
 	connFn := func() error {
 		sess, err := sql.Open("mysql", d.ConnectionURL().String())
 		if err == nil {
-			sess.SetConnMaxLifetime(connMaxLifetime)
-			sess.SetMaxIdleConns(maxIdleConns)
-			sess.SetMaxOpenConns(maxOpenConns)
+			sess.SetConnMaxLifetime(db.DefaultConnMaxLifetime)
+			sess.SetMaxIdleConns(db.DefaultMaxIdleConns)
+			sess.SetMaxOpenConns(db.DefaultMaxOpenConns)
 			return d.BaseDatabase.BindSession(sess)
 		}
 		return err
@@ -172,12 +176,12 @@ func (d *database) NewLocalCollection(name string) db.Collection {
 
 // Tx creates a transaction and passes it to the given function, if if the
 // function returns no error then the transaction is commited.
-func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error {
-	return sqladapter.RunTx(d, fn)
+func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error {
+	return sqladapter.RunTx(d, ctx, fn)
 }
 
 // NewLocalTransaction allows sqladapter start a transaction block.
-func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
+func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) {
 	clone, err := d.clone()
 	if err != nil {
 		return nil, err
@@ -187,9 +191,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
 	defer clone.txMu.Unlock()
 
 	connFn := func() error {
-		sqlTx, err := clone.BaseDatabase.Session().Begin()
+		sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil)
 		if err == nil {
-			return clone.BindTx(sqlTx)
+			return clone.BindTx(ctx, sqlTx)
 		}
 		return err
 	}
diff --git a/mysql/mysql.go b/mysql/mysql.go
index 7f8da66e..eb4e752f 100644
--- a/mysql/mysql.go
+++ b/mysql/mysql.go
@@ -23,7 +23,6 @@ package mysql // import "upper.io/db.v3/mysql"
 
 import (
 	"database/sql"
-	"time"
 
 	"upper.io/db.v3"
 
@@ -79,7 +78,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
 	}
 	d.Builder = b
 
-	if err := d.BaseDatabase.BindTx(sqlTx); err != nil {
+	if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil {
 		return nil, err
 	}
 
@@ -109,30 +108,3 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) {
 	}
 	return d, nil
 }
-
-// SetConnMaxLifetime sets the default value to be passed to
-// db.SetConnMaxLifetime.
-func SetConnMaxLifetime(d time.Duration) {
-	connMaxLifetime = d
-}
-
-// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns.
-func SetMaxIdleConns(n int) {
-	if n < 0 {
-		n = 0
-	}
-	maxIdleConns = n
-}
-
-// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns.
-// If the value of maxIdleConns is >= 0 and maxOpenConns is less than
-// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns.
-func SetMaxOpenConns(n int) {
-	if n < 0 {
-		n = 0
-	}
-	if n > maxIdleConns {
-		maxIdleConns = n
-	}
-	maxOpenConns = n
-}
diff --git a/postgresql/database.go b/postgresql/database.go
index a95975e6..493947b3 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -22,6 +22,7 @@
 package postgresql
 
 import (
+	"context"
 	"database/sql"
 	"strings"
 	"sync"
@@ -69,8 +70,11 @@ func (d *database) Open(connURL db.ConnectionURL) error {
 }
 
 // NewTx starts a transaction block.
-func (d *database) NewTx() (sqlbuilder.Tx, error) {
-	nTx, err := d.NewLocalTransaction()
+func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	nTx, err := d.NewLocalTransaction(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -111,9 +115,9 @@ func (d *database) open() error {
 	connFn := func() error {
 		sess, err := sql.Open("postgres", d.ConnectionURL().String())
 		if err == nil {
-			sess.SetConnMaxLifetime(connMaxLifetime)
-			sess.SetMaxIdleConns(maxIdleConns)
-			sess.SetMaxOpenConns(maxOpenConns)
+			sess.SetConnMaxLifetime(db.DefaultConnMaxLifetime)
+			sess.SetMaxIdleConns(db.DefaultMaxIdleConns)
+			sess.SetMaxOpenConns(db.DefaultMaxOpenConns)
 			return d.BaseDatabase.BindSession(sess)
 		}
 		return err
@@ -172,12 +176,12 @@ func (d *database) NewLocalCollection(name string) db.Collection {
 
 // Tx creates a transaction and passes it to the given function, if if the
 // function returns no error then the transaction is commited.
-func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error {
-	return sqladapter.RunTx(d, fn)
+func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error {
+	return sqladapter.RunTx(d, ctx, fn)
 }
 
 // NewLocalTransaction allows sqladapter start a transaction block.
-func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
+func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) {
 	clone, err := d.clone()
 	if err != nil {
 		return nil, err
@@ -187,9 +191,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
 	defer clone.txMu.Unlock()
 
 	connFn := func() error {
-		sqlTx, err := clone.BaseDatabase.Session().Begin()
+		sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil)
 		if err == nil {
-			return clone.BindTx(sqlTx)
+			return clone.BindTx(ctx, sqlTx)
 		}
 		return err
 	}
diff --git a/postgresql/local_test.go b/postgresql/local_test.go
index 1b7b5950..22f010cc 100644
--- a/postgresql/local_test.go
+++ b/postgresql/local_test.go
@@ -100,7 +100,7 @@ func TestIssue210(t *testing.T) {
 	sess := mustOpen()
 	defer sess.Close()
 
-	tx, err := sess.NewTx()
+	tx, err := sess.NewTx(nil)
 	assert.NoError(t, err)
 
 	for i := range list {
diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go
index bbbc5fff..e48e051c 100644
--- a/postgresql/postgresql.go
+++ b/postgresql/postgresql.go
@@ -23,7 +23,6 @@ package postgresql // import "upper.io/db.v3/postgresql"
 
 import (
 	"database/sql"
-	"time"
 
 	"upper.io/db.v3"
 
@@ -31,12 +30,6 @@ import (
 	"upper.io/db.v3/lib/sqlbuilder"
 )
 
-var (
-	connMaxLifetime = db.DefaultConnMaxLifetime
-	maxIdleConns    = db.DefaultMaxIdleConns
-	maxOpenConns    = db.DefaultMaxOpenConns
-)
-
 const sqlDriver = `postgres`
 
 // Adapter is the public name of the adapter.
@@ -79,7 +72,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
 	}
 	d.Builder = b
 
-	if err := d.BaseDatabase.BindTx(sqlTx); err != nil {
+	if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil {
 		return nil, err
 	}
 
@@ -109,30 +102,3 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) {
 	}
 	return d, nil
 }
-
-// SetConnMaxLifetime sets the default value to be passed to
-// db.SetConnMaxLifetime.
-func SetConnMaxLifetime(d time.Duration) {
-	connMaxLifetime = d
-}
-
-// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns.
-func SetMaxIdleConns(n int) {
-	if n < 0 {
-		n = 0
-	}
-	maxIdleConns = n
-}
-
-// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns.
-// If the value of maxIdleConns is >= 0 and maxOpenConns is less than
-// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns.
-func SetMaxOpenConns(n int) {
-	if n < 0 {
-		n = 0
-	}
-	if n > maxIdleConns {
-		maxIdleConns = n
-	}
-	maxOpenConns = n
-}
diff --git a/ql/database.go b/ql/database.go
index a71b951a..58566a48 100644
--- a/ql/database.go
+++ b/ql/database.go
@@ -113,7 +113,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
 	}
 	d.Builder = b
 
-	if err := d.BaseDatabase.BindTx(sqlTx); err != nil {
+	if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil {
 		return nil, err
 	}
 
@@ -145,8 +145,11 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) {
 }
 
 // NewTx starts a transaction block.
-func (d *database) NewTx() (sqlbuilder.Tx, error) {
-	nTx, err := d.NewLocalTransaction()
+func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) {
+	if ctx == nil {
+		ctx = d.Context()
+	}
+	nTx, err := d.NewLocalTransaction(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -272,12 +275,12 @@ func (d *database) NewLocalCollection(name string) db.Collection {
 
 // Tx creates a transaction and passes it to the given function, if if the
 // function returns no error then the transaction is commited.
-func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error {
-	return sqladapter.RunTx(d, fn)
+func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error {
+	return sqladapter.RunTx(d, ctx, fn)
 }
 
 // NewLocalTransaction allows sqladapter start a transaction block.
-func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
+func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) {
 	clone, err := d.clone()
 	if err != nil {
 		return nil, err
@@ -287,9 +290,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
 	defer clone.txMu.Unlock()
 
 	openFn := func() error {
-		sqlTx, err := clone.BaseDatabase.Session().Begin()
+		sqlTx, err := clone.BaseDatabase.Session().BeginTx(ctx, nil)
 		if err == nil {
-			return clone.BindTx(sqlTx)
+			return clone.BindTx(ctx, sqlTx)
 		}
 		return err
 	}
diff --git a/sqlite/database.go b/sqlite/database.go
index f1efa167..b04113cb 100644
--- a/sqlite/database.go
+++ b/sqlite/database.go
@@ -22,6 +22,7 @@
 package sqlite
 
 import (
+	"context"
 	"database/sql"
 	"errors"
 	"fmt"
@@ -85,8 +86,8 @@ func (d *database) Open(connURL db.ConnectionURL) error {
 }
 
 // NewTx starts a transaction block.
-func (d *database) NewTx() (sqlbuilder.Tx, error) {
-	nTx, err := d.NewLocalTransaction()
+func (d *database) NewTx(ctx context.Context) (sqlbuilder.Tx, error) {
+	nTx, err := d.NewLocalTransaction(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -215,12 +216,12 @@ func (d *database) NewLocalCollection(name string) db.Collection {
 
 // Tx creates a transaction and passes it to the given function, if if the
 // function returns no error then the transaction is commited.
-func (d *database) Tx(fn func(tx sqlbuilder.Tx) error) error {
-	return sqladapter.RunTx(d, fn)
+func (d *database) Tx(ctx context.Context, fn func(tx sqlbuilder.Tx) error) error {
+	return sqladapter.RunTx(d, ctx, fn)
 }
 
 // NewLocalTransaction allows sqladapter start a transaction block.
-func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
+func (d *database) NewLocalTransaction(ctx context.Context) (sqladapter.DatabaseTx, error) {
 	clone, err := d.clone()
 	if err != nil {
 		return nil, err
@@ -232,7 +233,7 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
 	openFn := func() error {
 		sqlTx, err := clone.BaseDatabase.Session().Begin()
 		if err == nil {
-			return clone.BindTx(sqlTx)
+			return clone.BindTx(ctx, sqlTx)
 		}
 		return err
 	}
diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go
index 894459a2..5c22c275 100644
--- a/sqlite/sqlite.go
+++ b/sqlite/sqlite.go
@@ -72,7 +72,7 @@ func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
 	}
 	d.Builder = b
 
-	if err := d.BaseDatabase.BindTx(sqlTx); err != nil {
+	if err := d.BaseDatabase.BindTx(d.Context(), sqlTx); err != nil {
 		return nil, err
 	}
 
-- 
GitLab