From facff58b3b2d069166a997c77e830c69863f1d28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Wed, 7 Oct 2015 07:31:07 -0500
Subject: [PATCH] Adding and using Iterator.Scan() and making the argument to
 Iterator.Next() optional.

---
 builder.go                      |   7 ++-
 builder/builder.go              |  31 ++++++++--
 error.go                        |   3 -
 internal/sqladapter/database.go |   2 +-
 postgresql/database.go          | 103 ++++++++++++++++++++------------
 5 files changed, 96 insertions(+), 50 deletions(-)

diff --git a/builder.go b/builder.go
index 7dc52ba9..c04b4719 100644
--- a/builder.go
+++ b/builder.go
@@ -82,9 +82,10 @@ type QueryGetter interface {
 }
 
 type Iterator interface {
-	All(interface{}) error
-	One(interface{}) error
-	Next(interface{}) bool
+	All(dest interface{}) error
+	One(dest interface{}) error
+	Scan(dest ...interface{}) error
+	Next(dest ...interface{}) bool
 	Err() error
 	Close() error
 }
diff --git a/builder/builder.go b/builder/builder.go
index 1a9d026c..7cebb281 100644
--- a/builder/builder.go
+++ b/builder/builder.go
@@ -637,6 +637,13 @@ func NewBuilder(sess sqlDatabase, t *sqlgen.Template) *Builder {
 	}
 }
 
+func (iter *iterator) Scan(dst ...interface{}) error {
+	if iter.err != nil {
+		return iter.err
+	}
+	return iter.cursor.Scan(dst...)
+}
+
 func (iter *iterator) One(dst interface{}) error {
 	if iter.err != nil {
 		return iter.err
@@ -670,20 +677,32 @@ func (iter *iterator) Err() (err error) {
 	return iter.err
 }
 
-func (iter *iterator) Next(dst interface{}) bool {
+func (iter *iterator) Next(dst ...interface{}) bool {
 	var err error
 
 	if iter.err != nil {
 		return false
 	}
 
-	if err = sqlutil.FetchRow(iter.cursor, dst); err != nil {
-		iter.err = err
-		iter.Close()
-		return false
+	switch len(dst) {
+	case 0:
+		if ok := iter.cursor.Next(); !ok {
+			iter.err = iter.cursor.Err()
+			iter.Close()
+			return false
+		}
+		return true
+	case 1:
+		if err = sqlutil.FetchRow(iter.cursor, dst[0]); err != nil {
+			iter.err = err
+			iter.Close()
+			return false
+		}
+		return true
 	}
 
-	return true
+	iter.err = db.ErrUnsupported
+	return false
 }
 
 func (iter *iterator) Close() (err error) {
diff --git a/error.go b/error.go
index 01f2b441..3bd51dcf 100644
--- a/error.go
+++ b/error.go
@@ -51,6 +51,3 @@ var (
 	ErrTooManyClients          = errors.New(`Can't connect to database server: too many clients.`)
 	ErrGivingUpTryingToConnect = errors.New(`Giving up trying to connect: too many clients.`)
 )
-
-// Deprecated but kept for backwards compatibility. See: https://github.com/upper/db/issues/18
-var ErrCollectionDoesNotExists = ErrCollectionDoesNotExist
diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go
index 5ac62d88..cdf408fb 100644
--- a/internal/sqladapter/database.go
+++ b/internal/sqladapter/database.go
@@ -6,7 +6,6 @@ import (
 	"time"
 
 	"github.com/jmoiron/sqlx"
-	_ "github.com/lib/pq" // PostgreSQL driver.
 	"upper.io/cache"
 	"upper.io/db"
 	"upper.io/db/builder"
@@ -141,6 +140,7 @@ func (d *BaseDatabase) C(name string) db.Collection {
 	if err != nil {
 		return &adapter.NonExistentCollection{Err: err}
 	}
+
 	return c
 }
 
diff --git a/postgresql/database.go b/postgresql/database.go
index 91360c59..9c249779 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -38,6 +38,9 @@ type database struct {
 
 var _ = db.Database(&database{})
 
+// CompileAndReplacePlaceholders compiles the given statement into an string
+// and replaces each generic placeholder with the placeholder the driver
+// expects (if any).
 func (d *database) CompileAndReplacePlaceholders(stmt *sqlgen.Statement) (query string) {
 	buf := stmt.Compile(d.Template())
 
@@ -54,15 +57,21 @@ func (d *database) CompileAndReplacePlaceholders(stmt *sqlgen.Statement) (query
 	return query
 }
 
+// Err translates some known errors into generic errors.
 func (d *database) Err(err error) error {
-	s := err.Error()
-	if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) {
-		return db.ErrTooManyClients
+	if err != nil {
+		s := err.Error()
+		if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) {
+			return db.ErrTooManyClients
+		}
+		if strings.Contains(s, `relation`) && strings.Contains(s, `does not exist`) {
+			return db.ErrCollectionDoesNotExist
+		}
 	}
 	return err
 }
 
-// Open attempts to connect to the database server using already stored settings.
+// Open attempts to open a connection to the database server.
 func (d *database) Open() error {
 	var sess *sqlx.DB
 
@@ -78,10 +87,8 @@ func (d *database) Open() error {
 	return d.Bind(sess)
 }
 
+// Setup configures the adapter.
 func (d *database) Setup(connURL db.ConnectionURL) error {
-	if d.BaseDatabase != nil {
-		d.Close()
-	}
 	d.BaseDatabase = sqladapter.NewDatabase(d, connURL, template.Template)
 	return d.Open()
 }
@@ -93,22 +100,19 @@ func (d *database) Use(name string) (err error) {
 		return err
 	}
 	conn.Database = name
-	return d.Setup(conn)
-}
-
-func (d *database) clone() (*database, error) {
-	clone := &database{}
-	clone.BaseDatabase = d.BaseDatabase.Clone(clone)
-	if err := clone.Open(); err != nil {
-		return nil, err
+	if d.BaseDatabase != nil {
+		d.Close()
 	}
-	return clone, nil
+	return d.Setup(conn)
 }
 
+// Clone creates a new database connection with the same settings as the
+// original.
 func (d *database) Clone() (db.Database, error) {
 	return d.clone()
 }
 
+// NewTable returns a db.Collection.
 func (d *database) NewTable(name string) db.Collection {
 	return newTable(d, name)
 }
@@ -121,13 +125,15 @@ func (d *database) Collections() (collections []string, err error) {
 			From("information_schema.tables").
 			Where("table_schema = ?", "public")
 
-		var row struct {
-			TableName string `db:"table_name"`
-		}
-
 		iter := q.Iterator()
-		for iter.Next(&row) {
-			d.Schema().AddTable(row.TableName)
+		defer iter.Close()
+
+		for iter.Next() {
+			var tableName string
+			if err := iter.Scan(&tableName); err != nil {
+				return nil, err
+			}
+			d.Schema().AddTable(tableName)
 		}
 	}
 
@@ -178,18 +184,22 @@ func (d *database) PopulateSchema() (err error) {
 
 	d.NewSchema()
 
-	// Get database name.
 	q := d.Builder().Select(db.Raw{"CURRENT_DATABASE() AS name"})
 
-	var row struct {
-		Name string `db:"name"`
-	}
+	var dbName string
 
-	if err := q.Iterator().One(&row); err != nil {
-		return err
+	iter := q.Iterator()
+	defer iter.Close()
+
+	if iter.Next() {
+		if err := iter.Scan(&dbName); err != nil {
+			return err
+		}
+	} else {
+		return iter.Err()
 	}
 
-	d.Schema().Name = row.Name
+	d.Schema().Name = dbName
 
 	if collections, err = d.Collections(); err != nil {
 		return err
@@ -204,6 +214,7 @@ func (d *database) PopulateSchema() (err error) {
 	return err
 }
 
+// TableExists checks whether a table exists and returns an error in case it doesn't.
 func (d *database) TableExists(name string) error {
 	if d.Schema().HasTable(name) {
 		return nil
@@ -213,15 +224,22 @@ func (d *database) TableExists(name string) error {
 		From("information_schema.tables").
 		Where("table_catalog = ? AND table_name = ?", d.Schema().Name, name)
 
-	var row map[string]string
+	iter := q.Iterator()
+	defer iter.Close()
 
-	if err := q.Iterator().One(&row); err != nil {
+	if iter.Next() {
+		var tableName string
+		if err := iter.Scan(&tableName); err != nil {
+			return err
+		}
+	} else {
 		return db.ErrCollectionDoesNotExist
 	}
 
 	return nil
 }
 
+// TableColumns returns all columns from the given table.
 func (d *database) TableColumns(tableName string) ([]string, error) {
 	s := d.Schema()
 
@@ -249,6 +267,7 @@ func (d *database) TableColumns(tableName string) ([]string, error) {
 	return s.Table(tableName).Columns, nil
 }
 
+// TablePrimaryKey returns all primary keys from the given table.
 func (d *database) TablePrimaryKey(tableName string) ([]string, error) {
 	s := d.Schema()
 
@@ -271,14 +290,24 @@ func (d *database) TablePrimaryKey(tableName string) ([]string, error) {
 		`).OrderBy("pkey")
 
 	iter := q.Iterator()
+	defer iter.Close()
 
-	var row struct {
-		Key string `db:"pkey"`
-	}
-
-	for iter.Next(&row) {
-		ts.PrimaryKey = append(ts.PrimaryKey, row.Key)
+	for iter.Next() {
+		var pKey string
+		if err := iter.Scan(&pKey); err != nil {
+			return nil, err
+		}
+		ts.PrimaryKey = append(ts.PrimaryKey, pKey)
 	}
 
 	return ts.PrimaryKey, nil
 }
+
+func (d *database) clone() (*database, error) {
+	clone := &database{}
+	clone.BaseDatabase = d.BaseDatabase.Clone(clone)
+	if err := clone.Open(); err != nil {
+		return nil, err
+	}
+	return clone, nil
+}
-- 
GitLab