From c395201d56bc75c037c8164bab7df47f2ba1fbb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Fri, 8 Jul 2016 09:39:46 -0500 Subject: [PATCH] Publish builder interfaces and add SQLDatabase and SQLTx. --- sqlbuilder/interfaces.go => builder.go | 2 +- db.go | 29 +++++++- internal/sqladapter/database.go | 3 +- internal/sqladapter/result.go | 9 ++- internal/sqladapter/tx.go | 5 +- lib/interfaces.go | 94 -------------------------- postgresql/database.go | 38 ++--------- postgresql/postgresql.go | 7 +- postgresql/tx.go | 11 +-- sqlbuilder/builder.go | 20 +++--- sqlbuilder/delete.go | 5 +- sqlbuilder/insert.go | 9 +-- sqlbuilder/select.go | 34 +++++----- sqlbuilder/update.go | 7 +- wrapper.go | 65 +++++++++--------- 15 files changed, 117 insertions(+), 221 deletions(-) rename sqlbuilder/interfaces.go => builder.go (99%) delete mode 100644 lib/interfaces.go diff --git a/sqlbuilder/interfaces.go b/builder.go similarity index 99% rename from sqlbuilder/interfaces.go rename to builder.go index ace702d3..61d0cbf2 100644 --- a/sqlbuilder/interfaces.go +++ b/builder.go @@ -21,7 +21,7 @@ // Package builder provides tools to compose, execute and map SQL queries to Go // structs and maps. -package builder // import "upper.io/db.v2/sqlbuilder" +package db import ( "database/sql" diff --git a/db.go b/db.go index ce292797..93b02420 100644 --- a/db.go +++ b/db.go @@ -413,8 +413,6 @@ type Database interface { // // err = tx.Commit() type Tx interface { - Database - // Rollback discards all the instructions on the current transaction. Rollback() error @@ -525,10 +523,35 @@ type ConnectionURL interface { String() string // Adapter returns the name of the adapter associated with the connection - // URL. + // URL, this name can be used as argument by the db.Adapter function to + // retrieve an imported adapter. Adapter() string } +// SQLDatabase represents a SQL database capable of creating transactions and +// use builder methods. +type SQLDatabase interface { + Database + Builder + + // NewTx returns a SQLTx in case the database supports transactions and the + // transaction can be initialized. + NewTx() (SQLTx, error) + + // Tx accepts a function which first argument is a SQLDatabase which runs + // within a transaction. If the fn function returns nil, the transaction is + // commited and the Tx function returns nil too, if fn returns an error, then + // the transaction is rolled back and the error is returned by Tx. + Tx(fn func(sess SQLTx) error) error +} + +// SQLTx represents a transaction on a SQL database. +type SQLTx interface { + Database + Builder + Tx +} + // EnvEnableDebug can be used by adapters to determine if the user has enabled // debugging. // diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 6ab98319..3e0c6c6d 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -8,7 +8,6 @@ import ( "upper.io/db.v2" "upper.io/db.v2/internal/logger" - "upper.io/db.v2/sqlbuilder" "upper.io/db.v2/sqlbuilder/cache" "upper.io/db.v2/sqlbuilder/exql" ) @@ -31,7 +30,7 @@ type Database interface { // PartialDatabase defines all the methods an adapter must provide. type PartialDatabase interface { - builder.Builder + db.Builder Collections() ([]string, error) Open(db.ConnectionURL) error diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go index 0775ff34..8bddc76b 100644 --- a/internal/sqladapter/result.go +++ b/internal/sqladapter/result.go @@ -25,14 +25,13 @@ import ( "sync" "upper.io/db.v2" - "upper.io/db.v2/sqlbuilder" ) // Result represents a delimited set of items bound by a condition. type Result struct { - b builder.Builder + b db.Builder table string - iter builder.Iterator + iter db.Iterator limit int offset int fields []interface{} @@ -51,7 +50,7 @@ func filter(conds []interface{}) []interface{} { // NewResult creates and Results a new Result set on the given table, this set // is limited by the given exql.Where conditions. -func NewResult(b builder.Builder, table string, conds []interface{}) *Result { +func NewResult(b db.Builder, table string, conds []interface{}) *Result { return &Result{ b: b, table: table, @@ -203,7 +202,7 @@ func (r *Result) Count() (uint64, error) { return counter.Count, nil } -func (r *Result) buildSelect() builder.Selector { +func (r *Result) buildSelect() db.Selector { q := r.b.Select(r.fields...) q.From(r.table) diff --git a/internal/sqladapter/tx.go b/internal/sqladapter/tx.go index 62d1d9b1..bd73ab3c 100644 --- a/internal/sqladapter/tx.go +++ b/internal/sqladapter/tx.go @@ -24,6 +24,7 @@ package sqladapter import ( "database/sql" "sync/atomic" + "upper.io/db.v2" ) // Tx represents a database session within a transaction. @@ -34,8 +35,8 @@ type DatabaseTx interface { // BaseTx defines methods to be implemented by a transaction. type BaseTx interface { - Commit() error - Rollback() error + db.Tx + Committed() bool } diff --git a/lib/interfaces.go b/lib/interfaces.go deleted file mode 100644 index 92654de8..00000000 --- a/lib/interfaces.go +++ /dev/null @@ -1,94 +0,0 @@ -package lib - -import ( - "database/sql" - "database/sql/driver" - "fmt" - - "upper.io/db.v2" - "upper.io/db.v2/sqlbuilder" -) - -// SQLDatabase represents a SQL database. -type SQLDatabase interface { - db.Database - builder.Builder - - NewTx() (SQLTx, error) - Tx(fn func(tx SQLTx) error) error -} - -type Tx interface { - Commit() error - Rollback() error -} - -// Tx represents a transaction. -type SQLTx interface { - SQLDatabase - Tx -} - -type SQLAdapter struct { - New func(sqlDB *sql.DB) (SQLDatabase, error) - NewTx func(sqlTx *sql.Tx) (SQLTx, error) - Open func(settings db.ConnectionURL) (SQLDatabase, error) -} - -var adapters map[string]*SQLAdapter - -func init() { - adapters = make(map[string]*SQLAdapter) -} - -func RegisterSQLAdapter(name string, fn *SQLAdapter) { - if _, ok := adapters[name]; ok { - panic(fmt.Errorf("upper: Adapter %q was already registered", name)) - } - adapters[name] = fn -} - -func Adapter(name string) SQLAdapter { - if fn, ok := adapters[name]; ok { - return *fn - } - return missingAdapter(name) -} - -func missingAdapter(name string) SQLAdapter { - err := fmt.Errorf("upper: Missing adapter %q, forgot to import?", name) - return SQLAdapter{ - New: func(*sql.DB) (SQLDatabase, error) { - return nil, err - }, - NewTx: func(*sql.Tx) (SQLTx, error) { - return nil, err - }, - Open: func(db.ConnectionURL) (SQLDatabase, error) { - return nil, err - }, - } -} - -type SQLTransaction interface { - SQLDriver - - Commit() error - Rollback() error -} - -type SQLDriver interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type SQLSession interface { - SQLDriver - - Begin() (*sql.Tx, error) - Close() error - Driver() driver.Driver - Ping() error -} diff --git a/postgresql/database.go b/postgresql/database.go index 971d9db4..b55eac9d 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -29,7 +29,6 @@ import ( _ "github.com/lib/pq" // PostgreSQL driver. "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter" - "upper.io/db.v2/lib" "upper.io/db.v2/sqlbuilder" "upper.io/db.v2/sqlbuilder/exql" ) @@ -37,7 +36,7 @@ import ( // database is the actual implementation of Database type database struct { sqladapter.BaseDatabase // Leveraged by sqladapter - builder.Builder + db.Builder txMu sync.Mutex tx sqladapter.DatabaseTx @@ -47,8 +46,7 @@ type database struct { } var ( - _ = sqladapter.Database(&database{}) - _ = lib.SQLDatabase(&database{}) + _ = db.SQLDatabase(&database{}) ) // newDatabase binds *database with sqladapter and the SQL builer. @@ -60,7 +58,7 @@ func newDatabase(settings db.ConnectionURL) (*database, error) { } // Open stablishes a new connection to a SQL server. -func Open(settings db.ConnectionURL) (lib.SQLDatabase, error) { +func Open(settings db.ConnectionURL) (db.SQLDatabase, error) { d, err := newDatabase(settings) if err != nil { return nil, err @@ -71,7 +69,7 @@ func Open(settings db.ConnectionURL) (lib.SQLDatabase, error) { return d, nil } -func NewTx(sqlTx *sql.Tx) (lib.SQLTx, error) { +func NewTx(sqlTx *sql.Tx) (db.SQLTx, error) { d, err := newDatabase(nil) if err != nil { return nil, err @@ -97,7 +95,7 @@ func NewTx(sqlTx *sql.Tx) (lib.SQLTx, error) { } // New wraps the given *sql.DB session and creates a new db session. -func New(sess *sql.DB) (lib.SQLDatabase, error) { +func New(sess *sql.DB) (db.SQLDatabase, error) { d, err := newDatabase(nil) if err != nil { return nil, err @@ -133,30 +131,8 @@ func (d *database) Open(connURL db.ConnectionURL) error { return d.open() } -func (d *database) UseTx(sqlTx *sql.Tx) (lib.SQLTx, error) { - if sqlTx == nil { // No transaction given. - d.txMu.Lock() - currentTx := d.tx - d.txMu.Unlock() - if currentTx != nil { - return &tx{DatabaseTx: currentTx}, nil - } - // Create a new transaction. - return d.NewTx() - } - - d.txMu.Lock() - defer d.txMu.Unlock() - - if err := d.BindTx(sqlTx); err != nil { - return nil, err - } - d.tx = sqladapter.NewTx(d) - return &tx{DatabaseTx: d.tx}, nil -} - // NewTx starts a transaction block. -func (d *database) NewTx() (lib.SQLTx, error) { +func (d *database) NewTx() (db.SQLTx, error) { nTx, err := d.NewLocalTransaction() if err != nil { return nil, err @@ -248,7 +224,7 @@ func (d *database) NewLocalCollection(name string) db.Collection { // Transaction creates a transaction and passes it to the given function, if // if the function returns no error then the transaction is commited. -func (d *database) Tx(fn func(tx lib.SQLTx) error) error { +func (d *database) Tx(fn func(tx db.SQLTx) error) error { tx, err := d.NewTx() if err != nil { return err diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index d80fc240..c4ec3519 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -23,7 +23,6 @@ package postgresql // import "upper.io/db.v2/postgresql" import ( "upper.io/db.v2" - "upper.io/db.v2/lib" ) const sqlDriver = `postgres` @@ -32,11 +31,7 @@ const sqlDriver = `postgres` const Adapter = `postgresql` func init() { - db.Register(Adapter, &database{}) -} - -func init() { - lib.RegisterSQLAdapter(Adapter, &lib.SQLAdapter{ + db.RegisterSQLAdapter(Adapter, &db.SQLAdapter{ New: New, NewTx: NewTx, Open: Open, diff --git a/postgresql/tx.go b/postgresql/tx.go index 2c894775..7f116aeb 100644 --- a/postgresql/tx.go +++ b/postgresql/tx.go @@ -24,7 +24,6 @@ package postgresql import ( "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter" - "upper.io/db.v2/lib" ) type tx struct { @@ -32,13 +31,5 @@ type tx struct { } var ( - _ = lib.SQLTx(&tx{}) + _ = db.SQLTx(&tx{}) ) - -func (t *tx) NewTx() (lib.SQLTx, error) { - return t, db.ErrAlreadyWithinTransaction -} - -func (t *tx) Tx(fn func(tx lib.SQLTx) error) error { - return fn(t) -} diff --git a/sqlbuilder/builder.go b/sqlbuilder/builder.go index 408a00f7..aad21449 100644 --- a/sqlbuilder/builder.go +++ b/sqlbuilder/builder.go @@ -55,7 +55,7 @@ type sqlBuilder struct { } // New returns a query builder that is bound to the given database session. -func New(sess interface{}, t *exql.Template) (Builder, error) { +func New(sess interface{}, t *exql.Template) (db.Builder, error) { switch v := sess.(type) { case *sql.DB: sess = newSqlgenProxy(v, t) @@ -72,18 +72,18 @@ func New(sess interface{}, t *exql.Template) (Builder, error) { } // NewBuilderWithTemplate returns a builder that is based on the given template. -func NewBuilderWithTemplate(t *exql.Template) Builder { +func NewBuilderWithTemplate(t *exql.Template) db.Builder { return &sqlBuilder{ t: newTemplateWithUtils(t), } } // NewIterator creates an iterator using the given *sql.Rows. -func NewIterator(rows *sql.Rows) Iterator { +func NewIterator(rows *sql.Rows) db.Iterator { return &iterator{rows, nil} } -func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) Iterator { +func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) db.Iterator { rows, err := b.Query(query, args...) return &iterator{rows, err} } @@ -121,7 +121,7 @@ func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, } } -func (b *sqlBuilder) SelectFrom(table string) Selector { +func (b *sqlBuilder) SelectFrom(table string) db.Selector { qs := &selector{ builder: b, table: table, @@ -131,7 +131,7 @@ func (b *sqlBuilder) SelectFrom(table string) Selector { return qs } -func (b *sqlBuilder) Select(columns ...interface{}) Selector { +func (b *sqlBuilder) Select(columns ...interface{}) db.Selector { qs := &selector{ builder: b, } @@ -140,7 +140,7 @@ func (b *sqlBuilder) Select(columns ...interface{}) Selector { return qs.Columns(columns...) } -func (b *sqlBuilder) InsertInto(table string) Inserter { +func (b *sqlBuilder) InsertInto(table string) db.Inserter { qi := &inserter{ builder: b, table: table, @@ -150,7 +150,7 @@ func (b *sqlBuilder) InsertInto(table string) Inserter { return qi } -func (b *sqlBuilder) DeleteFrom(table string) Deleter { +func (b *sqlBuilder) DeleteFrom(table string) db.Deleter { qd := &deleter{ builder: b, table: table, @@ -160,7 +160,7 @@ func (b *sqlBuilder) DeleteFrom(table string) Deleter { return qd } -func (b *sqlBuilder) Update(table string) Updater { +func (b *sqlBuilder) Update(table string) db.Updater { qu := &updater{ builder: b, table: table, @@ -436,6 +436,6 @@ func (p *exprProxy) StatementQueryRow(stmt *exql.Statement, args ...interface{}) } var ( - _ = Builder(&sqlBuilder{}) + _ = db.Builder(&sqlBuilder{}) _ = exprDB(&exprProxy{}) ) diff --git a/sqlbuilder/delete.go b/sqlbuilder/delete.go index 113d95e2..ec8722ad 100644 --- a/sqlbuilder/delete.go +++ b/sqlbuilder/delete.go @@ -3,6 +3,7 @@ package builder import ( "database/sql" + "upper.io/db.v2" "upper.io/db.v2/sqlbuilder/exql" ) @@ -15,14 +16,14 @@ type deleter struct { arguments []interface{} } -func (qd *deleter) Where(terms ...interface{}) Deleter { +func (qd *deleter) Where(terms ...interface{}) db.Deleter { where, arguments := qd.builder.t.ToWhereWithArguments(terms) qd.where = &where qd.arguments = append(qd.arguments, arguments...) return qd } -func (qd *deleter) Limit(limit int) Deleter { +func (qd *deleter) Limit(limit int) db.Deleter { qd.limit = limit return qd } diff --git a/sqlbuilder/insert.go b/sqlbuilder/insert.go index a2529fe1..5cb94530 100644 --- a/sqlbuilder/insert.go +++ b/sqlbuilder/insert.go @@ -3,6 +3,7 @@ package builder import ( "database/sql" + "upper.io/db.v2" "upper.io/db.v2/sqlbuilder/exql" ) @@ -27,7 +28,7 @@ func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) e return nil } -func (qi *inserter) Returning(columns ...string) Inserter { +func (qi *inserter) Returning(columns ...string) db.Inserter { qi.columnsToFragments(&qi.returning, columns) return qi } @@ -44,17 +45,17 @@ func (qi *inserter) QueryRow() (*sql.Row, error) { return qi.builder.sess.StatementQueryRow(qi.statement(), qi.arguments...) } -func (qi *inserter) Iterator() Iterator { +func (qi *inserter) Iterator() db.Iterator { rows, err := qi.builder.sess.StatementQuery(qi.statement(), qi.arguments...) return &iterator{rows, err} } -func (qi *inserter) Columns(columns ...string) Inserter { +func (qi *inserter) Columns(columns ...string) db.Inserter { qi.columnsToFragments(&qi.columns, columns) return qi } -func (qi *inserter) Values(values ...interface{}) Inserter { +func (qi *inserter) Values(values ...interface{}) db.Inserter { if len(qi.columns) == 0 && len(values) == 1 { ff, vv, _ := Map(values[0]) diff --git a/sqlbuilder/select.go b/sqlbuilder/select.go index 846d6fa6..e58480b6 100644 --- a/sqlbuilder/select.go +++ b/sqlbuilder/select.go @@ -33,12 +33,12 @@ type selector struct { err error } -func (qs *selector) From(tables ...string) Selector { +func (qs *selector) From(tables ...string) db.Selector { qs.table = strings.Join(tables, ",") return qs } -func (qs *selector) Columns(columns ...interface{}) Selector { +func (qs *selector) Columns(columns ...interface{}) db.Selector { f, err := columnFragments(qs.builder.t, columns) if err != nil { qs.err = err @@ -48,19 +48,19 @@ func (qs *selector) Columns(columns ...interface{}) Selector { return qs } -func (qs *selector) Distinct() Selector { +func (qs *selector) Distinct() db.Selector { qs.mode = selectModeDistinct return qs } -func (qs *selector) Where(terms ...interface{}) Selector { +func (qs *selector) Where(terms ...interface{}) db.Selector { where, arguments := qs.builder.t.ToWhereWithArguments(terms) qs.where = &where qs.arguments = append(qs.arguments, arguments...) return qs } -func (qs *selector) GroupBy(columns ...interface{}) Selector { +func (qs *selector) GroupBy(columns ...interface{}) db.Selector { var fragments []exql.Fragment fragments, qs.err = columnFragments(qs.builder.t, columns) if fragments != nil { @@ -69,7 +69,7 @@ func (qs *selector) GroupBy(columns ...interface{}) Selector { return qs } -func (qs *selector) OrderBy(columns ...interface{}) Selector { +func (qs *selector) OrderBy(columns ...interface{}) db.Selector { var sortColumns exql.SortColumns for i := range columns { @@ -108,7 +108,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { return qs } -func (qs *selector) Using(columns ...interface{}) Selector { +func (qs *selector) Using(columns ...interface{}) db.Selector { if len(qs.joins) == 0 { qs.err = errors.New(`Cannot use Using() without a preceding Join() expression.`) return qs @@ -131,7 +131,7 @@ func (qs *selector) Using(columns ...interface{}) Selector { return qs } -func (qs *selector) pushJoin(t string, tables []interface{}) Selector { +func (qs *selector) pushJoin(t string, tables []interface{}) db.Selector { if qs.joins == nil { qs.joins = []*exql.Join{} } @@ -151,27 +151,27 @@ func (qs *selector) pushJoin(t string, tables []interface{}) Selector { return qs } -func (qs *selector) FullJoin(tables ...interface{}) Selector { +func (qs *selector) FullJoin(tables ...interface{}) db.Selector { return qs.pushJoin("FULL", tables) } -func (qs *selector) CrossJoin(tables ...interface{}) Selector { +func (qs *selector) CrossJoin(tables ...interface{}) db.Selector { return qs.pushJoin("CROSS", tables) } -func (qs *selector) RightJoin(tables ...interface{}) Selector { +func (qs *selector) RightJoin(tables ...interface{}) db.Selector { return qs.pushJoin("RIGHT", tables) } -func (qs *selector) LeftJoin(tables ...interface{}) Selector { +func (qs *selector) LeftJoin(tables ...interface{}) db.Selector { return qs.pushJoin("LEFT", tables) } -func (qs *selector) Join(tables ...interface{}) Selector { +func (qs *selector) Join(tables ...interface{}) db.Selector { return qs.pushJoin("", tables) } -func (qs *selector) On(terms ...interface{}) Selector { +func (qs *selector) On(terms ...interface{}) db.Selector { if len(qs.joins) == 0 { qs.err = errors.New(`Cannot use On() without a preceding Join() expression.`) return qs @@ -192,12 +192,12 @@ func (qs *selector) On(terms ...interface{}) Selector { return qs } -func (qs *selector) Limit(n int) Selector { +func (qs *selector) Limit(n int) db.Selector { qs.limit = exql.Limit(n) return qs } -func (qs *selector) Offset(n int) Selector { +func (qs *selector) Offset(n int) db.Selector { qs.offset = exql.Offset(n) return qs } @@ -224,7 +224,7 @@ func (qs *selector) QueryRow() (*sql.Row, error) { return qs.builder.sess.StatementQueryRow(qs.statement(), qs.arguments...) } -func (qs *selector) Iterator() Iterator { +func (qs *selector) Iterator() db.Iterator { rows, err := qs.builder.sess.StatementQuery(qs.statement(), qs.arguments...) return &iterator{rows, err} } diff --git a/sqlbuilder/update.go b/sqlbuilder/update.go index eac23e9d..7b6b891d 100644 --- a/sqlbuilder/update.go +++ b/sqlbuilder/update.go @@ -3,6 +3,7 @@ package builder import ( "database/sql" + "upper.io/db.v2" "upper.io/db.v2/sqlbuilder/exql" ) @@ -16,7 +17,7 @@ type updater struct { arguments []interface{} } -func (qu *updater) Set(terms ...interface{}) Updater { +func (qu *updater) Set(terms ...interface{}) db.Updater { if len(terms) == 1 { ff, vv, _ := Map(terms[0]) @@ -49,7 +50,7 @@ func (qu *updater) Set(terms ...interface{}) Updater { return qu } -func (qu *updater) Where(terms ...interface{}) Updater { +func (qu *updater) Where(terms ...interface{}) db.Updater { where, arguments := qu.builder.t.ToWhereWithArguments(terms) qu.where = &where qu.arguments = append(qu.arguments, arguments...) @@ -60,7 +61,7 @@ func (qu *updater) Exec() (sql.Result, error) { return qu.builder.sess.StatementExec(qu.statement(), qu.arguments...) } -func (qu *updater) Limit(limit int) Updater { +func (qu *updater) Limit(limit int) db.Updater { qu.limit = limit return qu } diff --git a/wrapper.go b/wrapper.go index 29a8defe..4861334a 100644 --- a/wrapper.go +++ b/wrapper.go @@ -22,47 +22,50 @@ package db import ( + "database/sql" "fmt" - "reflect" ) -// This map holds a copy of all registered adapters. -var wrappers = make(map[string]Database) +var adapters map[string]*SQLAdapter -// Register associates an adapter's name with a type. Panics if the adapter -// name is empty or the adapter is nil. -func Register(name string, adapter Database) { +func init() { + adapters = make(map[string]*SQLAdapter) +} - if name == `` { - panic(`Missing adapter name.`) - } +type SQLAdapter struct { + New func(sqlDB *sql.DB) (SQLDatabase, error) + NewTx func(sqlTx *sql.Tx) (SQLTx, error) + Open func(settings ConnectionURL) (SQLDatabase, error) +} - if _, ok := wrappers[name]; ok != false { - panic(`db.Register() called twice for adapter: ` + name) +func RegisterSQLAdapter(name string, fn *SQLAdapter) { + if name == "" { + panic(`Missing adapter name`) } - - wrappers[name] = adapter + if _, ok := adapters[name]; ok { + panic(`db.RegisterSQLAdapter() called twice for adapter: ` + name) + } + adapters[name] = fn } -// Open configures a database session using the given adapter's name and the -// provided settings. -func Open(adapter string, conn ConnectionURL) (Database, error) { - - driver, ok := wrappers[adapter] - if !ok { - // Using panic instead of returning error because attemping to use an - // adapter that does not exists will never result in success. - panic(fmt.Sprintf(`db.Open: Unknown adapter %s. (see: https://upper.io/db.v2#database-adapters)`, adapter)) +func Adapter(name string) SQLAdapter { + if fn, ok := adapters[name]; ok { + return *fn } + return missingAdapter(name) +} - // Creating a new connection everytime Open() is called. - driverType := reflect.ValueOf(driver).Elem().Type() - newAdapter := reflect.New(driverType).Interface().(Database) - - // Setting up the connection. - if err := newAdapter.Open(conn); err != nil { - return nil, err +func missingAdapter(name string) SQLAdapter { + err := fmt.Errorf("upper: Missing adapter %q, forgot to import?", name) + return SQLAdapter{ + New: func(*sql.DB) (SQLDatabase, error) { + return nil, err + }, + NewTx: func(*sql.Tx) (SQLTx, error) { + return nil, err + }, + Open: func(ConnectionURL) (SQLDatabase, error) { + return nil, err + }, } - - return newAdapter, nil } -- GitLab