good morning!!!!

Skip to content
Snippets Groups Projects
Commit d5a46dfa authored by José Carlos Nieto's avatar José Carlos Nieto
Browse files

Testing different outcomes for transactions.

parent 834e8596
No related branches found
No related tags found
No related merge requests found
...@@ -88,6 +88,7 @@ type database struct { ...@@ -88,6 +88,7 @@ type database struct {
name string name string
sess *sql.DB sess *sql.DB
sessMu sync.Mutex
cachedStatements *cache.Cache cachedStatements *cache.Cache
cachedCollections *cache.Cache cachedCollections *cache.Cache
...@@ -106,7 +107,9 @@ func (d *database) Session() *sql.DB { ...@@ -106,7 +107,9 @@ func (d *database) Session() *sql.DB {
// BindTx binds a *sql.Tx into *database // BindTx binds a *sql.Tx into *database
func (d *database) BindTx(t *sql.Tx) error { func (d *database) BindTx(t *sql.Tx) error {
d.sessMu.Lock()
d.baseTx = newTx(t) d.baseTx = newTx(t)
defer d.sessMu.Unlock()
return d.Ping() return d.Ping()
} }
...@@ -130,7 +133,9 @@ func (d *database) Name() string { ...@@ -130,7 +133,9 @@ func (d *database) Name() string {
// BindSession binds a *sql.DB into *database // BindSession binds a *sql.DB into *database
func (d *database) BindSession(sess *sql.DB) error { func (d *database) BindSession(sess *sql.DB) error {
d.sessMu.Lock()
d.sess = sess d.sess = sess
d.sessMu.Unlock()
if err := d.Ping(); err != nil { if err := d.Ping(); err != nil {
return err return err
...@@ -154,11 +159,13 @@ func (d *database) Ping() error { ...@@ -154,11 +159,13 @@ func (d *database) Ping() error {
// Close terminates the current database session // Close terminates the current database session
func (d *database) Close() error { func (d *database) Close() error {
defer func() { defer func() {
d.sessMu.Lock()
d.sess = nil d.sess = nil
d.baseTx = nil d.baseTx = nil
d.sessMu.Unlock()
}() }()
if d.sess != nil { if d.sess != nil {
if d.Tx() != nil && !d.Tx().Commited() { if d.Tx() != nil && !d.Tx().Committed() {
d.Tx().Rollback() d.Tx().Rollback()
} }
d.cachedStatements.Clear() // Closes prepared statements as well. d.cachedStatements.Clear() // Closes prepared statements as well.
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"flag" "flag"
"log" "log"
"fmt"
"math/rand" "math/rand"
"os" "os"
"strconv" "strconv"
...@@ -929,7 +930,7 @@ func TestTransactionsAndRollback(t *testing.T) { ...@@ -929,7 +930,7 @@ func TestTransactionsAndRollback(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.Error(t, err, "Already commited") assert.Error(t, err, "Already committed")
// Let's verify we have 3 rows. // Let's verify we have 3 rows.
artist = sess.Collection("artist") artist = sess.Collection("artist")
...@@ -1083,38 +1084,103 @@ func TestExhaustConnectionPool(t *testing.T) { ...@@ -1083,38 +1084,103 @@ func TestExhaustConnectionPool(t *testing.T) {
return return
} }
var tMu sync.Mutex
tFatal := func(err error) {
tMu.Lock()
defer tMu.Unlock()
t.Fatal(err)
}
tLogf := func(format string, args... interface{}) {
tMu.Lock()
defer tMu.Unlock()
t.Logf(format, args...)
}
sess := mustOpen() sess := mustOpen()
defer sess.Close() defer sess.Close()
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 300; i++ { for i := 0; i < 300; i++ {
wg.Add(1) tLogf("Tx %d: Pending", i)
t.Logf("Tx %d: Pending", i)
go func(t *testing.T, wg *sync.WaitGroup, i int) { wg.Add(1)
var tx db.Tx go func(wg *sync.WaitGroup, i int) {
defer wg.Done() defer wg.Done()
start := time.Now()
// Requesting a new transaction session. // Requesting a new transaction session.
start := time.Now()
tx, err := sess.NewTransaction() tx, err := sess.NewTransaction()
if err != nil { if err != nil {
t.Fatal(err) tFatal(err)
} }
tLogf("Tx %d: OK (time to connect: %v)", i, time.Now().Sub(start))
t.Logf("Tx %d: OK (waiting time: %v)", i, time.Now().Sub(start))
// Let's suppose that we do a bunch of complex stuff and that the // Let's suppose that we do a bunch of complex stuff and that the
// transaction lasts 3 seconds. // transaction lasts 3 seconds.
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
switch i%7 {
case 0:
var account map[string]interface{}
if err := tx.Collection("artist").Find().One(&account); err != nil {
tFatal(err)
}
if err := tx.Commit(); err != nil {
tFatal(err)
}
tLogf("Tx %d: Committed", i)
case 1:
if _, err := tx.DeleteFrom("artist").Exec(); err != nil {
tFatal(err)
}
if err := tx.Rollback(); err != nil {
tFatal(err)
}
tLogf("Tx %d: Rolled back", i)
case 2:
if err := tx.Close(); err != nil { if err := tx.Close(); err != nil {
t.Fatal(err) tFatal(err)
} }
tLogf("Tx %d: Closed", i)
t.Logf("Tx %d: Done", i) case 3:
}(t, &wg, i) var account map[string]interface{}
if err := tx.Collection("artist").Find().One(&account); err != nil {
tFatal(err)
}
if err := tx.Commit(); err != nil {
tFatal(err)
}
if err := tx.Close(); err != nil {
tFatal(err)
}
tLogf("Tx %d: Committed and closed", i)
case 4:
if err := tx.Rollback(); err != nil {
tFatal(err)
}
if err := tx.Close(); err != nil {
tFatal(err)
}
tLogf("Tx %d: Rolled back and closed", i)
case 5:
if err := tx.Close(); err != nil {
tFatal(err)
}
if err := tx.Commit(); err == nil {
tFatal(fmt.Errorf("Error expected"))
}
tLogf("Tx %d: Closed and committed", i)
case 6:
if err := tx.Close(); err != nil {
tFatal(err)
}
if err := tx.Rollback(); err == nil {
tFatal(fmt.Errorf("Error expected"))
}
tLogf("Tx %d: Closed and rolled back", i)
}
}(&wg, i)
} }
wg.Wait() wg.Wait()
......
...@@ -36,7 +36,7 @@ type Tx interface { ...@@ -36,7 +36,7 @@ type Tx interface {
type BaseTx interface { type BaseTx interface {
Commit() error Commit() error
Rollback() error Rollback() error
Commited() bool Committed() bool
} }
type txWrapper struct { type txWrapper struct {
...@@ -61,16 +61,16 @@ func newTxWrapper(db Database) Tx { ...@@ -61,16 +61,16 @@ func newTxWrapper(db Database) Tx {
type sqlTx struct { type sqlTx struct {
*sql.Tx *sql.Tx
commited atomic.Value committed atomic.Value
} }
func newTx(tx *sql.Tx) BaseTx { func newTx(tx *sql.Tx) BaseTx {
return &sqlTx{Tx: tx} return &sqlTx{Tx: tx}
} }
func (t *sqlTx) Commited() bool { func (t *sqlTx) Committed() bool {
commited := t.commited.Load() committed := t.committed.Load()
if commited != nil { if committed != nil {
return true return true
} }
return false return false
...@@ -78,11 +78,21 @@ func (t *sqlTx) Commited() bool { ...@@ -78,11 +78,21 @@ func (t *sqlTx) Commited() bool {
func (t *sqlTx) Commit() (err error) { func (t *sqlTx) Commit() (err error) {
if err = t.Tx.Commit(); err == nil { if err = t.Tx.Commit(); err == nil {
t.commited.Store(struct{}{}) t.committed.Store(struct{}{})
} }
return err return err
} }
func (t *txWrapper) Commit() error {
defer t.Database.Close()
return t.BaseTx.Commit()
}
func (t *txWrapper) Rollback() error {
defer t.Database.Close()
return t.BaseTx.Rollback()
}
var ( var (
_ = BaseTx(&sqlTx{}) _ = BaseTx(&sqlTx{})
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment