From facff58b3b2d069166a997c77e830c69863f1d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Wed, 7 Oct 2015 07:31:07 -0500 Subject: [PATCH] Adding and using Iterator.Scan() and making the argument to Iterator.Next() optional. --- builder.go | 7 ++- builder/builder.go | 31 ++++++++-- error.go | 3 - internal/sqladapter/database.go | 2 +- postgresql/database.go | 103 ++++++++++++++++++++------------ 5 files changed, 96 insertions(+), 50 deletions(-) diff --git a/builder.go b/builder.go index 7dc52ba9..c04b4719 100644 --- a/builder.go +++ b/builder.go @@ -82,9 +82,10 @@ type QueryGetter interface { } type Iterator interface { - All(interface{}) error - One(interface{}) error - Next(interface{}) bool + All(dest interface{}) error + One(dest interface{}) error + Scan(dest ...interface{}) error + Next(dest ...interface{}) bool Err() error Close() error } diff --git a/builder/builder.go b/builder/builder.go index 1a9d026c..7cebb281 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -637,6 +637,13 @@ func NewBuilder(sess sqlDatabase, t *sqlgen.Template) *Builder { } } +func (iter *iterator) Scan(dst ...interface{}) error { + if iter.err != nil { + return iter.err + } + return iter.cursor.Scan(dst...) +} + func (iter *iterator) One(dst interface{}) error { if iter.err != nil { return iter.err @@ -670,20 +677,32 @@ func (iter *iterator) Err() (err error) { return iter.err } -func (iter *iterator) Next(dst interface{}) bool { +func (iter *iterator) Next(dst ...interface{}) bool { var err error if iter.err != nil { return false } - if err = sqlutil.FetchRow(iter.cursor, dst); err != nil { - iter.err = err - iter.Close() - return false + switch len(dst) { + case 0: + if ok := iter.cursor.Next(); !ok { + iter.err = iter.cursor.Err() + iter.Close() + return false + } + return true + case 1: + if err = sqlutil.FetchRow(iter.cursor, dst[0]); err != nil { + iter.err = err + iter.Close() + return false + } + return true } - return true + iter.err = db.ErrUnsupported + return false } func (iter *iterator) Close() (err error) { diff --git a/error.go b/error.go index 01f2b441..3bd51dcf 100644 --- a/error.go +++ b/error.go @@ -51,6 +51,3 @@ var ( ErrTooManyClients = errors.New(`Can't connect to database server: too many clients.`) ErrGivingUpTryingToConnect = errors.New(`Giving up trying to connect: too many clients.`) ) - -// Deprecated but kept for backwards compatibility. See: https://github.com/upper/db/issues/18 -var ErrCollectionDoesNotExists = ErrCollectionDoesNotExist diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 5ac62d88..cdf408fb 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -6,7 +6,6 @@ import ( "time" "github.com/jmoiron/sqlx" - _ "github.com/lib/pq" // PostgreSQL driver. "upper.io/cache" "upper.io/db" "upper.io/db/builder" @@ -141,6 +140,7 @@ func (d *BaseDatabase) C(name string) db.Collection { if err != nil { return &adapter.NonExistentCollection{Err: err} } + return c } diff --git a/postgresql/database.go b/postgresql/database.go index 91360c59..9c249779 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -38,6 +38,9 @@ type database struct { var _ = db.Database(&database{}) +// CompileAndReplacePlaceholders compiles the given statement into an string +// and replaces each generic placeholder with the placeholder the driver +// expects (if any). func (d *database) CompileAndReplacePlaceholders(stmt *sqlgen.Statement) (query string) { buf := stmt.Compile(d.Template()) @@ -54,15 +57,21 @@ func (d *database) CompileAndReplacePlaceholders(stmt *sqlgen.Statement) (query return query } +// Err translates some known errors into generic errors. func (d *database) Err(err error) error { - s := err.Error() - if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) { - return db.ErrTooManyClients + if err != nil { + s := err.Error() + if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) { + return db.ErrTooManyClients + } + if strings.Contains(s, `relation`) && strings.Contains(s, `does not exist`) { + return db.ErrCollectionDoesNotExist + } } return err } -// Open attempts to connect to the database server using already stored settings. +// Open attempts to open a connection to the database server. func (d *database) Open() error { var sess *sqlx.DB @@ -78,10 +87,8 @@ func (d *database) Open() error { return d.Bind(sess) } +// Setup configures the adapter. func (d *database) Setup(connURL db.ConnectionURL) error { - if d.BaseDatabase != nil { - d.Close() - } d.BaseDatabase = sqladapter.NewDatabase(d, connURL, template.Template) return d.Open() } @@ -93,22 +100,19 @@ func (d *database) Use(name string) (err error) { return err } conn.Database = name - return d.Setup(conn) -} - -func (d *database) clone() (*database, error) { - clone := &database{} - clone.BaseDatabase = d.BaseDatabase.Clone(clone) - if err := clone.Open(); err != nil { - return nil, err + if d.BaseDatabase != nil { + d.Close() } - return clone, nil + return d.Setup(conn) } +// Clone creates a new database connection with the same settings as the +// original. func (d *database) Clone() (db.Database, error) { return d.clone() } +// NewTable returns a db.Collection. func (d *database) NewTable(name string) db.Collection { return newTable(d, name) } @@ -121,13 +125,15 @@ func (d *database) Collections() (collections []string, err error) { From("information_schema.tables"). Where("table_schema = ?", "public") - var row struct { - TableName string `db:"table_name"` - } - iter := q.Iterator() - for iter.Next(&row) { - d.Schema().AddTable(row.TableName) + defer iter.Close() + + for iter.Next() { + var tableName string + if err := iter.Scan(&tableName); err != nil { + return nil, err + } + d.Schema().AddTable(tableName) } } @@ -178,18 +184,22 @@ func (d *database) PopulateSchema() (err error) { d.NewSchema() - // Get database name. q := d.Builder().Select(db.Raw{"CURRENT_DATABASE() AS name"}) - var row struct { - Name string `db:"name"` - } + var dbName string - if err := q.Iterator().One(&row); err != nil { - return err + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + if err := iter.Scan(&dbName); err != nil { + return err + } + } else { + return iter.Err() } - d.Schema().Name = row.Name + d.Schema().Name = dbName if collections, err = d.Collections(); err != nil { return err @@ -204,6 +214,7 @@ func (d *database) PopulateSchema() (err error) { return err } +// TableExists checks whether a table exists and returns an error in case it doesn't. func (d *database) TableExists(name string) error { if d.Schema().HasTable(name) { return nil @@ -213,15 +224,22 @@ func (d *database) TableExists(name string) error { From("information_schema.tables"). Where("table_catalog = ? AND table_name = ?", d.Schema().Name, name) - var row map[string]string + iter := q.Iterator() + defer iter.Close() - if err := q.Iterator().One(&row); err != nil { + if iter.Next() { + var tableName string + if err := iter.Scan(&tableName); err != nil { + return err + } + } else { return db.ErrCollectionDoesNotExist } return nil } +// TableColumns returns all columns from the given table. func (d *database) TableColumns(tableName string) ([]string, error) { s := d.Schema() @@ -249,6 +267,7 @@ func (d *database) TableColumns(tableName string) ([]string, error) { return s.Table(tableName).Columns, nil } +// TablePrimaryKey returns all primary keys from the given table. func (d *database) TablePrimaryKey(tableName string) ([]string, error) { s := d.Schema() @@ -271,14 +290,24 @@ func (d *database) TablePrimaryKey(tableName string) ([]string, error) { `).OrderBy("pkey") iter := q.Iterator() + defer iter.Close() - var row struct { - Key string `db:"pkey"` - } - - for iter.Next(&row) { - ts.PrimaryKey = append(ts.PrimaryKey, row.Key) + for iter.Next() { + var pKey string + if err := iter.Scan(&pKey); err != nil { + return nil, err + } + ts.PrimaryKey = append(ts.PrimaryKey, pKey) } return ts.PrimaryKey, nil } + +func (d *database) clone() (*database, error) { + clone := &database{} + clone.BaseDatabase = d.BaseDatabase.Clone(clone) + if err := clone.Open(); err != nil { + return nil, err + } + return clone, nil +} -- GitLab