From 9c2a9f19d9568abad75d878c3a0264061f182c41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Sun, 27 Nov 2016 23:41:32 -0600
Subject: [PATCH] Make prepared statement cache optional, refactor and improve
 tests

---
 config.go                                  |  44 +++++--
 internal/sqladapter/collection.go          |   5 +
 internal/sqladapter/database.go            | 134 ++++++++++++++++-----
 internal/sqladapter/statement.go           |  60 +++++----
 internal/sqladapter/testing/adapter.go.tpl |  41 +++++--
 mysql/database.go                          |   9 +-
 postgresql/database.go                     |   8 +-
 ql/Makefile                                |   2 +-
 ql/database.go                             |  43 ++++---
 sqlite/database.go                         |  44 ++++---
 sqlite/tx.go                               |  18 +--
 11 files changed, 263 insertions(+), 145 deletions(-)

diff --git a/config.go b/config.go
index f0e34e82..2c21554e 100644
--- a/config.go
+++ b/config.go
@@ -37,10 +37,18 @@ type Settings interface {
 	SetLogger(Logger)
 	// Returns the currently configured logger.
 	Logger() Logger
+
+	// SetPreparedStatementCache enables or disables the prepared statement
+	// cache.
+	SetPreparedStatementCache(bool)
+	// PreparedStatementCacheEnabled returns true if the prepared statement cache
+	// is enabled, false otherwise.
+	PreparedStatementCacheEnabled() bool
 }
 
 type conf struct {
-	loggingEnabled uint32
+	loggingEnabled                uint32
+	preparedStatementCacheEnabled uint32
 
 	queryLogger   Logger
 	queryLoggerMu sync.RWMutex
@@ -65,20 +73,38 @@ func (c *conf) SetLogger(lg Logger) {
 	c.queryLogger = lg
 }
 
-func (c *conf) SetLogging(value bool) {
+func (c *conf) binaryOption(opt *uint32) bool {
+	if atomic.LoadUint32(opt) == 1 {
+		return true
+	}
+	return false
+}
+
+func (c *conf) setBinaryOption(opt *uint32, value bool) {
 	if value {
-		atomic.StoreUint32(&c.loggingEnabled, 1)
+		atomic.StoreUint32(opt, 1)
 		return
 	}
-	atomic.StoreUint32(&c.loggingEnabled, 0)
+	atomic.StoreUint32(opt, 0)
+}
+
+func (c *conf) SetLogging(value bool) {
+	c.setBinaryOption(&c.loggingEnabled, value)
 }
 
 func (c *conf) LoggingEnabled() bool {
-	if v := atomic.LoadUint32(&c.loggingEnabled); v == 1 {
-		return true
-	}
-	return false
+	return c.binaryOption(&c.loggingEnabled)
+}
+
+func (c *conf) SetPreparedStatementCache(value bool) {
+	c.setBinaryOption(&c.preparedStatementCacheEnabled, value)
+}
+
+func (c *conf) PreparedStatementCacheEnabled() bool {
+	return c.binaryOption(&c.preparedStatementCacheEnabled)
 }
 
 // Conf provides global configuration settings for upper-db.
-var Conf Settings = &conf{}
+var Conf Settings = &conf{
+	preparedStatementCacheEnabled: 0,
+}
diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go
index bb21e27a..004c0049 100644
--- a/internal/sqladapter/collection.go
+++ b/internal/sqladapter/collection.go
@@ -3,6 +3,7 @@ package sqladapter
 import (
 	"fmt"
 	"reflect"
+	"sync"
 
 	"upper.io/db.v2"
 	"upper.io/db.v2/internal/sqladapter/exql"
@@ -35,6 +36,7 @@ type BaseCollection interface {
 type collection struct {
 	p  PartialCollection
 	pk []string
+	mu sync.Mutex
 }
 
 // NewBaseCollection returns a collection with basic methods.
@@ -68,6 +70,9 @@ func (c *collection) Exists() bool {
 
 // InsertReturning inserts an item and updates the given variable reference.
 func (c *collection) InsertReturning(item interface{}) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	if reflect.TypeOf(item).Kind() != reflect.Ptr {
 		return fmt.Errorf("Expecting a pointer to map or string but got %T", item)
 	}
diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go
index d3bf392f..20c9a0ad 100644
--- a/internal/sqladapter/database.go
+++ b/internal/sqladapter/database.go
@@ -27,7 +27,7 @@ type HasCleanUp interface {
 
 // HasStatementExec allows the adapter to have its own exec statement.
 type HasStatementExec interface {
-	StatementExec(stmt *sql.Stmt, args ...interface{}) (sql.Result, error)
+	StatementExec(query string, args ...interface{}) (sql.Result, error)
 }
 
 // Database represents a SQL database.
@@ -77,6 +77,8 @@ type BaseDatabase interface {
 	SetConnMaxLifetime(time.Duration)
 	SetMaxIdleConns(int)
 	SetMaxOpenConns(int)
+
+	BindClone(PartialDatabase) (BaseDatabase, error)
 }
 
 // NewBaseDatabase provides a BaseDatabase given a PartialDatabase
@@ -102,6 +104,8 @@ type database struct {
 	sess   *sql.DB
 	sessMu sync.Mutex
 
+	psMu sync.Mutex
+
 	sessID uint64
 	txID   uint64
 
@@ -217,6 +221,20 @@ func (d *database) ClearCache() {
 	}
 }
 
+// BindClone binds a clone that is linked to the current
+// session. This is commonly done before creating a transaction
+// session.
+func (d *database) BindClone(p PartialDatabase) (BaseDatabase, error) {
+	nd := NewBaseDatabase(p).(*database)
+	nd.name = d.name
+	nd.sess = d.sess
+	if err := nd.Ping(); err != nil {
+		return nil, err
+	}
+	nd.sessID = newSessionID()
+	return nd, nil
+}
+
 // Close terminates the current database session
 func (d *database) Close() error {
 	defer func() {
@@ -229,6 +247,7 @@ func (d *database) Close() error {
 		if cleaner, ok := d.PartialDatabase.(HasCleanUp); ok {
 			cleaner.CleanUp()
 		}
+
 		d.cachedCollections.Clear()
 		d.cachedStatements.Clear() // Closes prepared statements as well.
 
@@ -240,6 +259,7 @@ func (d *database) Close() error {
 
 		if !tx.Committed() {
 			tx.Rollback()
+			return nil
 		}
 	}
 	return nil
@@ -295,18 +315,33 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res
 		}(time.Now())
 	}
 
-	var p *Stmt
-	if p, query, err = d.prepareStatement(stmt); err != nil {
-		return nil, err
+	if execer, ok := d.PartialDatabase.(HasStatementExec); ok {
+		query = d.compileStatement(stmt)
+		res, err = execer.StatementExec(query, args...)
+		return
 	}
-	defer p.Close()
 
-	if execer, ok := d.PartialDatabase.(HasStatementExec); ok {
-		res, err = execer.StatementExec(p.Stmt, args...)
+	tx := d.Transaction()
+
+	if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
+		var p *Stmt
+		if p, query, err = d.prepareStatement(stmt); err != nil {
+			return nil, err
+		}
+		defer p.Close()
+
+		res, err = p.Exec(args...)
 		return
 	}
 
-	res, err = p.Exec(args...)
+	query = d.compileStatement(stmt)
+
+	if tx != nil {
+		res, err = tx.(*sqlTx).Exec(query, args...)
+		return
+	}
+
+	res, err = d.sess.Exec(query, args...)
 	return
 }
 
@@ -328,14 +363,28 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro
 		}(time.Now())
 	}
 
-	var p *Stmt
-	if p, query, err = d.prepareStatement(stmt); err != nil {
-		return nil, err
+	tx := d.Transaction()
+
+	if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
+		var p *Stmt
+		if p, query, err = d.prepareStatement(stmt); err != nil {
+			return nil, err
+		}
+		defer p.Close()
+
+		rows, err = p.Query(args...)
+		return
+	}
+
+	query = d.compileStatement(stmt)
+	if tx != nil {
+		rows, err = tx.(*sqlTx).Query(query, args...)
+		return
 	}
-	defer p.Close()
 
-	rows, err = p.Query(args...)
+	rows, err = d.sess.Query(query, args...)
 	return
+
 }
 
 // StatementQueryRow compiles and executes a statement that returns at most one
@@ -357,13 +406,26 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{})
 		}(time.Now())
 	}
 
-	var p *Stmt
-	if p, query, err = d.prepareStatement(stmt); err != nil {
-		return nil, err
+	tx := d.Transaction()
+
+	if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
+		var p *Stmt
+		if p, query, err = d.prepareStatement(stmt); err != nil {
+			return nil, err
+		}
+		defer p.Close()
+
+		row = p.QueryRow(args...)
+		return
 	}
-	defer p.Close()
 
-	row, err = p.QueryRow(args...), nil
+	query = d.compileStatement(stmt)
+	if tx != nil {
+		row = tx.(*sqlTx).QueryRow(query, args...)
+		return
+	}
+
+	row = d.sess.QueryRow(query, args...)
 	return
 }
 
@@ -376,14 +438,19 @@ func (d *database) Driver() interface{} {
 	return d.sess
 }
 
-// prepareStatement converts a *exql.Statement representation into an actual
-// *sql.Stmt.  This method will attempt to used a cached prepared statement, if
-// available.
+// compileStatement compiles the given statement into a string.
+func (d *database) compileStatement(stmt *exql.Statement) string {
+	return d.PartialDatabase.CompileStatement(stmt)
+}
+
+// prepareStatement compiles a query and tries to use previously generated
+// statement.
 func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) {
 	d.sessMu.Lock()
 	defer d.sessMu.Unlock()
 
-	if d.sess == nil && d.Transaction() == nil {
+	sess, tx := d.sess, d.Transaction()
+	if sess == nil && tx == nil {
 		return nil, "", db.ErrNotConnected
 	}
 
@@ -396,22 +463,23 @@ func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error)
 		}
 	}
 
-	// Plain SQL query.
-	query := d.PartialDatabase.CompileStatement(stmt)
-
-	sqlStmt, err := func() (*sql.Stmt, error) {
-		if d.Transaction() != nil {
-			return d.Transaction().(*sqlTx).Prepare(query)
+	query := d.compileStatement(stmt)
+	sqlStmt, err := func(query *string) (*sql.Stmt, error) {
+		if tx != nil {
+			return tx.(*sqlTx).Prepare(*query)
 		}
-		return d.sess.Prepare(query)
-	}()
+		return sess.Prepare(*query)
+	}(&query)
 	if err != nil {
-		return nil, query, err
+		return nil, "", err
 	}
 
-	p := NewStatement(sqlStmt, query)
+	p, err := NewStatement(sqlStmt, query).Open()
+	if err != nil {
+		return nil, query, err
+	}
 	d.cachedStatements.Write(stmt, p)
-	return p, query, nil
+	return p, p.query, nil
 }
 
 var waitForConnMu sync.Mutex
diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go
index 9abee8f4..17e7c6d7 100644
--- a/internal/sqladapter/statement.go
+++ b/internal/sqladapter/statement.go
@@ -3,6 +3,7 @@ package sqladapter
 import (
 	"database/sql"
 	"errors"
+	"sync"
 	"sync/atomic"
 )
 
@@ -10,21 +11,16 @@ var (
 	activeStatements int64
 )
 
-// NumActiveStatements returns the number of prepared statements in use at any
-// point.
-func NumActiveStatements() int64 {
-	return atomic.LoadInt64(&activeStatements)
-}
-
 // Stmt represents a *sql.Stmt that is cached and provides the
 // OnPurge method to allow it to clean after itself.
 type Stmt struct {
 	*sql.Stmt
 
 	query string
+	mu    sync.Mutex
 
 	count int64
-	dead  int32
+	dead  bool
 }
 
 // NewStatement creates an returns an opened statement
@@ -32,44 +28,58 @@ func NewStatement(stmt *sql.Stmt, query string) *Stmt {
 	s := &Stmt{
 		Stmt:  stmt,
 		query: query,
-		count: 1,
 	}
-	// Increment active statements counter.
 	atomic.AddInt64(&activeStatements, 1)
 	return s
 }
 
 // Open marks the statement as in-use
 func (c *Stmt) Open() (*Stmt, error) {
-	if atomic.LoadInt32(&c.dead) > 0 {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.dead {
 		return nil, errors.New("statement is dead")
 	}
-	atomic.AddInt64(&c.count, 1)
+
+	c.count++
 	return c, nil
 }
 
 // Close closes the underlying statement if no other go-routine is using it.
-func (c *Stmt) Close() (err error) {
-	if atomic.AddInt64(&c.count, -1) > 0 {
-		// If this counter is more than 0 then there are other goroutines using
-		// this statement so we don't want to close it for real.
-		return
-	}
+func (c *Stmt) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.count--
 
-	if atomic.LoadInt32(&c.dead) > 0 && atomic.LoadInt64(&c.count) <= 0 {
+	return c.checkClose()
+}
+
+func (c *Stmt) checkClose() error {
+	if c.dead && c.count == 0 {
 		// Statement is dead and we can close it for real.
-		err = c.Stmt.Close()
+		err := c.Stmt.Close()
+		if err != nil {
+			return err
+		}
 		// Reduce active statements counter.
 		atomic.AddInt64(&activeStatements, -1)
 	}
-	return
+	return nil
 }
 
 // OnPurge marks the statement as ready to be cleaned up.
 func (c *Stmt) OnPurge() {
-	// Mark as dead, you can continue using it but it will be closed for real
-	// when c.count reaches 0.
-	atomic.StoreInt32(&c.dead, 1)
-	// Call Close again to make sure we're closing the statement.
-	c.Close()
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.dead = true
+	c.checkClose()
+}
+
+// NumActiveStatements returns the global number of prepared statements in use
+// at any point.
+func NumActiveStatements() int64 {
+	return atomic.LoadInt64(&activeStatements)
 }
diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index 6d810585..0f4cae94 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -78,6 +78,9 @@ func TestOpenMustSucceed(t *testing.T) {
 func TestPreparedStatementsCache(t *testing.T) {
 	sess := mustOpen()
 
+	db.Conf.SetPreparedStatementCache(true)
+	defer db.Conf.SetPreparedStatementCache(false)
+
 	var tMu sync.Mutex
 	tFatal := func(err error) {
 		tMu.Lock()
@@ -88,8 +91,8 @@ func TestPreparedStatementsCache(t *testing.T) {
 	// This limit was chosen because, by default, MySQL accepts 16k statements
 	// and dies. See https://github.com/upper/db/issues/287
 	limit := 20000
-
 	var wg sync.WaitGroup
+
 	for i := 0; i < limit; i++ {
 		wg.Add(1)
 		go func(i int) {
@@ -108,17 +111,7 @@ func TestPreparedStatementsCache(t *testing.T) {
 
 	// Concurrent Insert can open many connections on MySQL / PostgreSQL, this
 	// sets a limit to them.
-	maxOpenConns := 100
-	if Adapter == "sqlite" {
-		// We can't use sqlite3 for multiple writes concurrently.
-		// https://github.com/mattn/go-sqlite3#faq
-		//
-		// The right thing here would be using bulk insertion, but that's not what
-		// we're testing.
-		limit = 10
-	}
-	sess.SetMaxOpenConns(maxOpenConns)
-	log.Printf("limit: %v, maxOpenConns: %v", limit, maxOpenConns)
+	//sess.SetMaxOpenConns(100)
 
 	for i := 0; i < limit; i++ {
 		wg.Add(1)
@@ -126,7 +119,7 @@ func TestPreparedStatementsCache(t *testing.T) {
 			defer wg.Done()
 			// The same prepared query on every iteration.
 			_, err := sess.Collection("artist").Insert(artistType{
-        Name: fmt.Sprintf("artist-%d", i%200),
+        Name: fmt.Sprintf("artist-%d", i),
       })
 			if err != nil {
 				tFatal(err)
@@ -135,6 +128,23 @@ func TestPreparedStatementsCache(t *testing.T) {
 	}
 	wg.Wait()
 
+	// Insert returning creates a transaction.
+	for i := 0; i < limit; i++ {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			// The same prepared query on every iteration.
+			artist := artistType{
+        Name: fmt.Sprintf("artist-%d", i),
+      }
+			err := sess.Collection("artist").InsertReturning(&artist)
+			if err != nil {
+				tFatal(err)
+			}
+		}(i)
+	}
+	wg.Wait()
+
 	// Removing the limit.
 	sess.SetMaxOpenConns(0)
 
@@ -1008,6 +1018,7 @@ func TestCompositeKeys(t *testing.T) {
 
 // Attempts to test database transactions.
 func TestTransactionsAndRollback(t *testing.T) {
+
 	if Adapter == "ql" {
 		t.Skip("Currently not supported.")
 	}
@@ -1036,8 +1047,12 @@ func TestTransactionsAndRollback(t *testing.T) {
 	err = tx.Close()
 	assert.NoError(t, err)
 
+	err = tx.Close()
+	assert.NoError(t, err)
+
 	// Use another transaction.
 	tx, err = sess.NewTx()
+	assert.NoError(t, err)
 
 	artist = tx.Collection("artist")
 
diff --git a/mysql/database.go b/mysql/database.go
index f8c2479c..a00d9a7c 100644
--- a/mysql/database.go
+++ b/mysql/database.go
@@ -133,17 +133,16 @@ func (d *database) clone() (*database, error) {
 		return nil, err
 	}
 
-	clone.BaseDatabase = sqladapter.NewBaseDatabase(clone)
-
-	b, err := sqlbuilder.WithSession(clone.BaseDatabase, template)
+	clone.BaseDatabase, err = d.BindClone(clone)
 	if err != nil {
 		return nil, err
 	}
-	clone.Builder = b
 
-	if err = clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil {
+	b, err := sqlbuilder.WithSession(clone.BaseDatabase, template)
+	if err != nil {
 		return nil, err
 	}
+	clone.Builder = b
 
 	return clone, nil
 }
diff --git a/postgresql/database.go b/postgresql/database.go
index 7920207e..ed4257bd 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -132,7 +132,10 @@ func (d *database) clone() (*database, error) {
 		return nil, err
 	}
 
-	clone.BaseDatabase = sqladapter.NewBaseDatabase(clone)
+	clone.BaseDatabase, err = d.BindClone(clone)
+	if err != nil {
+		return nil, err
+	}
 
 	b, err := sqlbuilder.WithSession(clone.BaseDatabase, template)
 	if err != nil {
@@ -140,9 +143,6 @@ func (d *database) clone() (*database, error) {
 	}
 	clone.Builder = b
 
-	if err = clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil {
-		return nil, err
-	}
 	return clone, nil
 }
 
diff --git a/ql/Makefile b/ql/Makefile
index 436aedb9..25e3c9a8 100644
--- a/ql/Makefile
+++ b/ql/Makefile
@@ -22,4 +22,4 @@ reset-db: require-client
 
 test: reset-db generate
 	#go test -tags generated -v -race # race: limit on 8192 simultaneously alive goroutines is exceeded, dying
-	go test -tags generated -v
+	go test -tags generated -timeout 30m -v
diff --git a/ql/database.go b/ql/database.go
index 3a5134cc..16116fb8 100644
--- a/ql/database.go
+++ b/ql/database.go
@@ -211,9 +211,16 @@ func (d *database) clone() (*database, error) {
 		return nil, err
 	}
 
-	if err := clone.open(); err != nil {
+	clone.BaseDatabase, err = d.BindClone(clone)
+	if err != nil {
+		return nil, err
+	}
+
+	b, err := sqlbuilder.WithSession(clone.BaseDatabase, template)
+	if err != nil {
 		return nil, err
 	}
+	clone.Builder = b
 
 	return clone, nil
 }
@@ -235,29 +242,25 @@ func (d *database) Err(err error) error {
 }
 
 // StatementExec wraps the statement to execute around a transaction.
-func (d *database) StatementExec(stmt *sql.Stmt, args ...interface{}) (sql.Result, error) {
-	if d.BaseDatabase.Transaction() == nil {
-		var tx *sql.Tx
-		var res sql.Result
-		var err error
-
-		if tx, err = d.Session().Begin(); err != nil {
-			return nil, err
-		}
-
-		s := tx.Stmt(stmt)
+func (d *database) StatementExec(query string, args ...interface{}) (res sql.Result, err error) {
+	if d.Transaction() != nil {
+		return d.Driver().(*sql.Tx).Exec(query, args...)
+	}
 
-		if res, err = s.Exec(args...); err != nil {
-			return nil, err
-		}
+	sqlTx, err := d.Session().Begin()
+	if err != nil {
+		return nil, err
+	}
 
-		if err = tx.Commit(); err != nil {
-			return nil, err
-		}
+	if res, err = sqlTx.Exec(query, args...); err != nil {
+		return nil, err
+	}
 
-		return res, err
+	if err = sqlTx.Commit(); err != nil {
+		return nil, err
 	}
-	return stmt.Exec(args...)
+
+	return res, err
 }
 
 // NewLocalCollection allows sqladapter create a local db.Collection.
diff --git a/sqlite/database.go b/sqlite/database.go
index b72569dc..074a764f 100644
--- a/sqlite/database.go
+++ b/sqlite/database.go
@@ -153,9 +153,16 @@ func (d *database) clone() (*database, error) {
 		return nil, err
 	}
 
-	if err := clone.open(); err != nil {
+	clone.BaseDatabase, err = d.BindClone(clone)
+	if err != nil {
+		return nil, err
+	}
+
+	b, err := sqlbuilder.WithSession(clone.BaseDatabase, template)
+	if err != nil {
 		return nil, err
 	}
+	clone.Builder = b
 
 	return clone, nil
 }
@@ -177,29 +184,28 @@ func (d *database) Err(err error) error {
 }
 
 // StatementExec wraps the statement to execute around a transaction.
-func (d *database) StatementExec(stmt *sql.Stmt, args ...interface{}) (sql.Result, error) {
-	if d.BaseDatabase.Transaction() == nil {
-		var tx *sql.Tx
-		var res sql.Result
-		var err error
+func (d *database) StatementExec(query string, args ...interface{}) (res sql.Result, err error) {
+	d.txMu.Lock()
+	defer d.txMu.Unlock()
 
-		if tx, err = d.Session().Begin(); err != nil {
-			return nil, err
-		}
-
-		s := tx.Stmt(stmt)
+	if d.Transaction() != nil {
+		return d.Driver().(*sql.Tx).Exec(query, args...)
+	}
 
-		if res, err = s.Exec(args...); err != nil {
-			return nil, err
-		}
+	sqlTx, err := d.Session().Begin()
+	if err != nil {
+		return nil, err
+	}
 
-		if err = tx.Commit(); err != nil {
-			return nil, err
-		}
+	if res, err = sqlTx.Exec(query, args...); err != nil {
+		return nil, err
+	}
 
-		return res, err
+	if err = sqlTx.Commit(); err != nil {
+		return nil, err
 	}
-	return stmt.Exec(args...)
+
+	return res, err
 }
 
 // NewLocalCollection allows sqladapter create a local db.Collection.
diff --git a/sqlite/tx.go b/sqlite/tx.go
index 19948754..c39e76e0 100644
--- a/sqlite/tx.go
+++ b/sqlite/tx.go
@@ -22,8 +22,8 @@
 package sqlite
 
 import (
-	"upper.io/db.v2"
 	"upper.io/db.v2/internal/sqladapter"
+	"upper.io/db.v2/lib/sqlbuilder"
 )
 
 type tx struct {
@@ -31,19 +31,5 @@ type tx struct {
 }
 
 var (
-	_ = db.Tx(&tx{})
+	_ = sqlbuilder.Tx(&tx{})
 )
-
-func (t *tx) Commit() error {
-	if sess := t.Session(); sess != nil {
-		defer sess.Close()
-	}
-	return t.DatabaseTx.Commit()
-}
-
-func (t *tx) Rollback() error {
-	if sess := t.Session(); sess != nil {
-		defer sess.Close()
-	}
-	return t.DatabaseTx.Rollback()
-}
-- 
GitLab