diff --git a/db.go b/db.go index 2699c229924c5499c0af980cd357f7f517a243b4..ce292797e0418897146b731baa96ffa0505fe891 100644 --- a/db.go +++ b/db.go @@ -523,6 +523,10 @@ type ConnectionURL interface { // String returns the connection string that is going to be passed to the // adapter. String() string + + // Adapter returns the name of the adapter associated with the connection + // URL. + Adapter() string } // EnvEnableDebug can be used by adapters to determine if the user has enabled diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go index 8c68192c62c4e12d5c75d40a393b3b16ef52dff1..f47e474b7db76c86b6beb26d4361066632e2f1c7 100644 --- a/internal/sqladapter/collection.go +++ b/internal/sqladapter/collection.go @@ -71,10 +71,10 @@ func (c *collection) InsertReturning(item interface{}) error { return fmt.Errorf("Expecting a pointer to map or string but got %T", item) } - var tx Tx + var tx DatabaseTx inTx := false - if currTx := c.p.Database().Tx(); currTx != nil { + if currTx := c.p.Database().Transaction(); currTx != nil { tx = newTxWrapper(c.p.Database()) inTx = true } else { diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index f6f4750b5d1698794593c1a3bb687a0573a45a03..6ab98319b57309065d842aa650dae1b6c6ff6bf5 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -46,7 +46,7 @@ type PartialDatabase interface { ConnectionURL() db.ConnectionURL Err(in error) (out error) - NewLocalTransaction() (Tx, error) + NewLocalTransaction() (DatabaseTx, error) } // BaseDatabase defines the methods provided by sqladapter that do not have to @@ -65,7 +65,7 @@ type BaseDatabase interface { Session() *sql.DB BindTx(*sql.Tx) error - Tx() BaseTx + Transaction() BaseTx } // NewBaseDatabase provides a BaseDatabase given a PartialDatabase @@ -116,7 +116,7 @@ func (d *database) BindTx(t *sql.Tx) error { // Tx returns a BaseTx, which, if not nil, means that this session is within a // transaction -func (d *database) Tx() BaseTx { +func (d *database) Transaction() BaseTx { return d.baseTx } @@ -180,8 +180,8 @@ func (d *database) Close() error { d.sessMu.Unlock() }() if d.sess != nil { - if d.Tx() != nil && !d.Tx().Committed() { - d.Tx().Rollback() + if d.Transaction() != nil && !d.Transaction().Committed() { + d.Transaction().Rollback() } d.cachedCollections.Clear() d.cachedStatements.Clear() // Closes prepared statements as well. @@ -288,7 +288,7 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) // Driver returns the underlying *sql.DB or *sql.Tx instance. func (d *database) Driver() interface{} { - if tx := d.Tx(); tx != nil { + if tx := d.Transaction(); tx != nil { // A transaction return tx.(*sqlTx).Tx } @@ -299,7 +299,7 @@ func (d *database) Driver() interface{} { // *sql.Stmt. This method will attempt to used a cached prepared statement, if // available. func (d *database) prepareStatement(stmt *exql.Statement) (*sql.Stmt, string, error) { - if d.sess == nil { + if d.sess == nil && d.Transaction() == nil { return nil, "", db.ErrNotConnected } @@ -317,8 +317,8 @@ func (d *database) prepareStatement(stmt *exql.Statement) (*sql.Stmt, string, er var p *sql.Stmt var err error - if d.Tx() != nil { - p, err = d.Tx().(*sqlTx).Prepare(query) + if d.Transaction() != nil { + p, err = d.Transaction().(*sqlTx).Prepare(query) } else { p, err = d.sess.Prepare(query) } diff --git a/internal/sqladapter/tx.go b/internal/sqladapter/tx.go index 41973866a8f7305a74b9becd7fca33b010e50301..62d1d9b1eeba279293ad7bebec6a70f1f256a99e 100644 --- a/internal/sqladapter/tx.go +++ b/internal/sqladapter/tx.go @@ -27,7 +27,7 @@ import ( ) // Tx represents a database session within a transaction. -type Tx interface { +type DatabaseTx interface { Database BaseTx } @@ -45,17 +45,17 @@ type txWrapper struct { } // NewTx creates a database session within a transaction. -func NewTx(db Database) Tx { +func NewTx(db Database) DatabaseTx { return &txWrapper{ Database: db, - BaseTx: db.Tx(), + BaseTx: db.Transaction(), } } -func newTxWrapper(db Database) Tx { +func newTxWrapper(db Database) DatabaseTx { return &txWrapper{ Database: db, - BaseTx: db.Tx(), + BaseTx: db.Transaction(), } } @@ -84,12 +84,10 @@ func (t *sqlTx) Commit() (err error) { } func (t *txWrapper) Commit() error { - defer t.Database.Close() return t.BaseTx.Commit() } func (t *txWrapper) Rollback() error { - defer t.Database.Close() return t.BaseTx.Rollback() } diff --git a/lib/interfaces.go b/lib/interfaces.go new file mode 100644 index 0000000000000000000000000000000000000000..92654de8e036a348c3cc01b01d5fd188484a47db --- /dev/null +++ b/lib/interfaces.go @@ -0,0 +1,94 @@ +package lib + +import ( + "database/sql" + "database/sql/driver" + "fmt" + + "upper.io/db.v2" + "upper.io/db.v2/sqlbuilder" +) + +// SQLDatabase represents a SQL database. +type SQLDatabase interface { + db.Database + builder.Builder + + NewTx() (SQLTx, error) + Tx(fn func(tx SQLTx) error) error +} + +type Tx interface { + Commit() error + Rollback() error +} + +// Tx represents a transaction. +type SQLTx interface { + SQLDatabase + Tx +} + +type SQLAdapter struct { + New func(sqlDB *sql.DB) (SQLDatabase, error) + NewTx func(sqlTx *sql.Tx) (SQLTx, error) + Open func(settings db.ConnectionURL) (SQLDatabase, error) +} + +var adapters map[string]*SQLAdapter + +func init() { + adapters = make(map[string]*SQLAdapter) +} + +func RegisterSQLAdapter(name string, fn *SQLAdapter) { + if _, ok := adapters[name]; ok { + panic(fmt.Errorf("upper: Adapter %q was already registered", name)) + } + adapters[name] = fn +} + +func Adapter(name string) SQLAdapter { + if fn, ok := adapters[name]; ok { + return *fn + } + return missingAdapter(name) +} + +func missingAdapter(name string) SQLAdapter { + err := fmt.Errorf("upper: Missing adapter %q, forgot to import?", name) + return SQLAdapter{ + New: func(*sql.DB) (SQLDatabase, error) { + return nil, err + }, + NewTx: func(*sql.Tx) (SQLTx, error) { + return nil, err + }, + Open: func(db.ConnectionURL) (SQLDatabase, error) { + return nil, err + }, + } +} + +type SQLTransaction interface { + SQLDriver + + Commit() error + Rollback() error +} + +type SQLDriver interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +type SQLSession interface { + SQLDriver + + Begin() (*sql.Tx, error) + Close() error + Driver() driver.Driver + Ping() error +} diff --git a/postgresql/connection.go b/postgresql/connection.go index 07c0fb54e8bf820545402e776b4a10a9aae5eae7..041a3050a073ce78e0ef6ec5d2e1933ec7fb59fa 100644 --- a/postgresql/connection.go +++ b/postgresql/connection.go @@ -90,6 +90,10 @@ type ConnectionURL struct { var escaper = strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) +func (c ConnectionURL) Adapter() string { + return Adapter +} + func (c ConnectionURL) String() (s string) { u := make([]string, 0, 6) diff --git a/postgresql/database.go b/postgresql/database.go index 80b0836a9100015da137146ff25de428acda01cb..971d9db4ece96b38f3312776cca742277f4a83ee 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -29,35 +29,26 @@ import ( _ "github.com/lib/pq" // PostgreSQL driver. "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter" + "upper.io/db.v2/lib" "upper.io/db.v2/sqlbuilder" "upper.io/db.v2/sqlbuilder/exql" ) -// Database represents a SQL database. -type Database interface { - db.Database - builder.Builder - - UseTransaction(tx *sql.Tx) (Tx, error) - NewTransaction() (Tx, error) - - Transaction(fn func(tx Tx) error) error -} - // database is the actual implementation of Database type database struct { sqladapter.BaseDatabase // Leveraged by sqladapter builder.Builder - txMu sync.Mutex - tx sqladapter.Tx + txMu sync.Mutex + tx sqladapter.DatabaseTx + cloned bool connURL db.ConnectionURL } var ( _ = sqladapter.Database(&database{}) - _ = db.Database(&database{}) + _ = lib.SQLDatabase(&database{}) ) // newDatabase binds *database with sqladapter and the SQL builer. @@ -69,7 +60,7 @@ func newDatabase(settings db.ConnectionURL) (*database, error) { } // Open stablishes a new connection to a SQL server. -func Open(settings db.ConnectionURL) (Database, error) { +func Open(settings db.ConnectionURL) (lib.SQLDatabase, error) { d, err := newDatabase(settings) if err != nil { return nil, err @@ -80,8 +71,33 @@ func Open(settings db.ConnectionURL) (Database, error) { return d, nil } +func NewTx(sqlTx *sql.Tx) (lib.SQLTx, error) { + d, err := newDatabase(nil) + if err != nil { + return nil, err + } + + // Binding with sqladapter's logic. + d.BaseDatabase = sqladapter.NewBaseDatabase(d) + + // Binding with builder. + b, err := builder.New(d.BaseDatabase, template) + if err != nil { + return nil, err + } + d.Builder = b + + if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + return nil, err + } + + d.tx = sqladapter.NewTx(d) + + return &tx{DatabaseTx: d.tx}, nil +} + // New wraps the given *sql.DB session and creates a new db session. -func New(sess *sql.DB) (Database, error) { +func New(sess *sql.DB) (lib.SQLDatabase, error) { d, err := newDatabase(nil) if err != nil { return nil, err @@ -117,17 +133,16 @@ func (d *database) Open(connURL db.ConnectionURL) error { return d.open() } -// UseTransaction makes the adapter use the given transaction. -func (d *database) UseTransaction(sqlTx *sql.Tx) (Tx, error) { +func (d *database) UseTx(sqlTx *sql.Tx) (lib.SQLTx, error) { if sqlTx == nil { // No transaction given. d.txMu.Lock() currentTx := d.tx d.txMu.Unlock() if currentTx != nil { - return &tx{Tx: currentTx}, nil + return &tx{DatabaseTx: currentTx}, nil } // Create a new transaction. - return d.NewTransaction() + return d.NewTx() } d.txMu.Lock() @@ -137,16 +152,16 @@ func (d *database) UseTransaction(sqlTx *sql.Tx) (Tx, error) { return nil, err } d.tx = sqladapter.NewTx(d) - return &tx{Tx: d.tx}, nil + return &tx{DatabaseTx: d.tx}, nil } -// NewTransaction starts a transaction block. -func (d *database) NewTransaction() (Tx, error) { +// NewTx starts a transaction block. +func (d *database) NewTx() (lib.SQLTx, error) { nTx, err := d.NewLocalTransaction() if err != nil { return nil, err } - return &tx{Tx: nTx}, nil + return &tx{DatabaseTx: nTx}, nil } // Collections returns a list of non-system tables from the database. @@ -200,6 +215,7 @@ func (d *database) clone() (*database, error) { if err != nil { return nil, err } + clone.cloned = true if err := clone.open(); err != nil { return nil, err @@ -232,8 +248,8 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Transaction creates a transaction and passes it to the given function, if // if the function returns no error then the transaction is commited. -func (d *database) Transaction(fn func(tx Tx) error) error { - tx, err := d.NewTransaction() +func (d *database) Tx(fn func(tx lib.SQLTx) error) error { + tx, err := d.NewTx() if err != nil { return err } @@ -246,7 +262,7 @@ func (d *database) Transaction(fn func(tx Tx) error) error { } // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.Tx, error) { +func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index 27645ded0711d90bff6cc9e5493103e00a3c48b1..d80fc240989b42f8c88a854d0cbc20beea449ffc 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -23,6 +23,7 @@ package postgresql // import "upper.io/db.v2/postgresql" import ( "upper.io/db.v2" + "upper.io/db.v2/lib" ) const sqlDriver = `postgres` @@ -33,3 +34,11 @@ const Adapter = `postgresql` func init() { db.Register(Adapter, &database{}) } + +func init() { + lib.RegisterSQLAdapter(Adapter, &lib.SQLAdapter{ + New: New, + NewTx: NewTx, + Open: Open, + }) +} diff --git a/postgresql/tx.go b/postgresql/tx.go index 1e4f1c4462cc6d6633ce240d7586d6c26599dc3c..2c89477548101c958c7337a8d23c15ac0c3c0cbc 100644 --- a/postgresql/tx.go +++ b/postgresql/tx.go @@ -22,45 +22,23 @@ package postgresql import ( - "database/sql" - "fmt" - "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter" + "upper.io/db.v2/lib" ) -// Tx represents a transaction. -type Tx interface { - Database - - Commit() error - Rollback() error -} - type tx struct { - sqladapter.Tx + sqladapter.DatabaseTx } var ( - _ = db.Tx(&tx{}) + _ = lib.SQLTx(&tx{}) ) -func (t *tx) NewTransaction() (Tx, error) { +func (t *tx) NewTx() (lib.SQLTx, error) { return t, db.ErrAlreadyWithinTransaction } -func (t *tx) UseTransaction(sqlTx *sql.Tx) (Tx, error) { - return t, db.ErrAlreadyWithinTransaction -} - -func (t *tx) With(interface{}) (Database, error) { - return nil, fmt.Errorf("Not implemented.") -} - -func (t *tx) Transaction(fn func(tx Tx) error) error { - if err := fn(t); err != nil { - t.Rollback() - return err - } - return t.Commit() +func (t *tx) Tx(fn func(tx lib.SQLTx) error) error { + return fn(t) }