diff --git a/postgresql/database.go b/postgresql/database.go index c2f3d774e096cf5018a09e6867b22cb4544af02a..80b0836a9100015da137146ff25de428acda01cb 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -22,11 +22,10 @@ package postgresql import ( - "strings" - "database/sql" + "strings" + "sync" - "fmt" _ "github.com/lib/pq" // PostgreSQL driver. "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter" @@ -39,8 +38,10 @@ type Database interface { db.Database builder.Builder + UseTransaction(tx *sql.Tx) (Tx, error) NewTransaction() (Tx, error) - With(interface{}) (Database, error) + + Transaction(fn func(tx Tx) error) error } // database is the actual implementation of Database @@ -48,6 +49,9 @@ type database struct { sqladapter.BaseDatabase // Leveraged by sqladapter builder.Builder + txMu sync.Mutex + tx sqladapter.Tx + connURL db.ConnectionURL } @@ -64,35 +68,36 @@ func newDatabase(settings db.ConnectionURL) (*database, error) { return d, nil } -func (d *database) With(sess interface{}) (Database, error) { - clone, err := newDatabase(d.connURL) +// Open stablishes a new connection to a SQL server. +func Open(settings db.ConnectionURL) (Database, error) { + d, err := newDatabase(settings) if err != nil { return nil, err } + if err := d.Open(settings); err != nil { + return nil, err + } + return d, nil +} - switch t := sess.(type) { - case *sql.DB: - if err := clone.BindSession(t); err != nil { - return nil, err - } - case *sql.Tx: - if err := clone.BindTx(t); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("Unknown session type %T", t) +// New wraps the given *sql.DB session and creates a new db session. +func New(sess *sql.DB) (Database, error) { + d, err := newDatabase(nil) + if err != nil { + return nil, err } - return clone, nil -} + // Binding with sqladapter's logic. + d.BaseDatabase = sqladapter.NewBaseDatabase(d) -// Open stablishes a new connection to a SQL server. -func Open(settings db.ConnectionURL) (Database, error) { - d, err := newDatabase(settings) + // Binding with builder. + b, err := builder.New(d.BaseDatabase, template) if err != nil { return nil, err } - if err := d.Open(settings); err != nil { + d.Builder = b + + if err := d.BaseDatabase.BindSession(sess); err != nil { return nil, err } return d, nil @@ -112,6 +117,29 @@ func (d *database) Open(connURL db.ConnectionURL) error { return d.open() } +// UseTransaction makes the adapter use the given transaction. +func (d *database) UseTransaction(sqlTx *sql.Tx) (Tx, error) { + if sqlTx == nil { // No transaction given. + d.txMu.Lock() + currentTx := d.tx + d.txMu.Unlock() + if currentTx != nil { + return &tx{Tx: currentTx}, nil + } + // Create a new transaction. + return d.NewTransaction() + } + + d.txMu.Lock() + defer d.txMu.Unlock() + + if err := d.BindTx(sqlTx); err != nil { + return nil, err + } + d.tx = sqladapter.NewTx(d) + return &tx{Tx: d.tx}, nil +} + // NewTransaction starts a transaction block. func (d *database) NewTransaction() (Tx, error) { nTx, err := d.NewLocalTransaction() @@ -202,6 +230,21 @@ func (d *database) NewLocalCollection(name string) db.Collection { return newTable(d, name) } +// Transaction creates a transaction and passes it to the given function, if +// if the function returns no error then the transaction is commited. +func (d *database) Transaction(fn func(tx Tx) error) error { + tx, err := d.NewTransaction() + if err != nil { + return err + } + defer tx.Close() + if err := fn(tx); err != nil { + tx.Rollback() + return err + } + return tx.Commit() +} + // NewLocalTransaction allows sqladapter start a transaction block. func (d *database) NewLocalTransaction() (sqladapter.Tx, error) { clone, err := d.clone() @@ -209,6 +252,9 @@ func (d *database) NewLocalTransaction() (sqladapter.Tx, error) { return nil, err } + clone.txMu.Lock() + defer clone.txMu.Unlock() + connFn := func() error { sqlTx, err := clone.BaseDatabase.Session().Begin() if err == nil { @@ -221,6 +267,8 @@ func (d *database) NewLocalTransaction() (sqladapter.Tx, error) { return nil, err } + clone.tx = sqladapter.NewTx(clone) + return sqladapter.NewTx(clone), nil } diff --git a/postgresql/tx.go b/postgresql/tx.go index 4a5a70b01aaf2f86f5254887d207d6a0075e7a2a..1e4f1c4462cc6d6633ce240d7586d6c26599dc3c 100644 --- a/postgresql/tx.go +++ b/postgresql/tx.go @@ -22,6 +22,7 @@ package postgresql import ( + "database/sql" "fmt" "upper.io/db.v2" @@ -48,6 +49,18 @@ func (t *tx) NewTransaction() (Tx, error) { return t, db.ErrAlreadyWithinTransaction } +func (t *tx) UseTransaction(sqlTx *sql.Tx) (Tx, error) { + return t, db.ErrAlreadyWithinTransaction +} + func (t *tx) With(interface{}) (Database, error) { return nil, fmt.Errorf("Not implemented.") } + +func (t *tx) Transaction(fn func(tx Tx) error) error { + if err := fn(t); err != nil { + t.Rollback() + return err + } + return t.Commit() +}