From 0f065e510a33d9f1b952f097ab33b94725e899fa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 6 Oct 2015 04:16:16 -0500
Subject: [PATCH] Moving collection's shared logic to an internal package.

---
 builder.go                        |  2 ++
 builder/builder.go                |  9 +++++++
 internal/sqladapter/collection.go | 39 ++++++++++++++++++++++++++++
 internal/sqladapter/database.go   | 11 ++++++++
 postgresql/collection.go          | 43 ++++++++++---------------------
 postgresql/database.go            | 28 +++++++-------------
 6 files changed, 83 insertions(+), 49 deletions(-)
 create mode 100644 internal/sqladapter/collection.go
 create mode 100644 internal/sqladapter/database.go

diff --git a/builder.go b/builder.go
index 3dc52af6..7dc52ba9 100644
--- a/builder.go
+++ b/builder.go
@@ -14,6 +14,8 @@ type QueryBuilder interface {
 	InsertInto(table string) QueryInserter
 	DeleteFrom(table string) QueryDeleter
 	Update(table string) QueryUpdater
+
+	Exec(query interface{}, args ...interface{}) (sql.Result, error)
 }
 
 type QuerySelector interface {
diff --git a/builder/builder.go b/builder/builder.go
index a895c84f..1a9d026c 100644
--- a/builder/builder.go
+++ b/builder/builder.go
@@ -63,6 +63,15 @@ type Builder struct {
 	t    *sqlutil.TemplateWithUtils
 }
 
+func (b *Builder) Exec(query interface{}, args ...interface{}) (sql.Result, error) {
+	switch q := query.(type) {
+	case *sqlgen.Statement:
+		return b.sess.Exec(q, args...)
+	default:
+		return nil, errors.New("Unsupported query type.")
+	}
+}
+
 func (b *Builder) TruncateTable(table string) db.QueryTruncater {
 	qs := &QueryTruncater{
 		builder: b,
diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go
new file mode 100644
index 00000000..bd13ec55
--- /dev/null
+++ b/internal/sqladapter/collection.go
@@ -0,0 +1,39 @@
+package sqladapter
+
+import (
+	"upper.io/db"
+	"upper.io/db/util/sqlutil/result"
+)
+
+type Collection struct {
+	database  Database
+	tableName string
+}
+
+// NewCollection returns a collection with basic methods.
+func NewCollection(d Database, tableName string) *Collection {
+	return &Collection{database: d, tableName: tableName}
+}
+
+// Name returns the name of the table.
+func (c *Collection) Name() string {
+	return c.tableName
+}
+
+// Exists returns true if the collection exists.
+func (c *Collection) Exists() bool {
+	if err := c.Database().TableExists(c.Name()); err != nil {
+		return false
+	}
+	return true
+}
+
+// Find creates a result set with the given conditions.
+func (c *Collection) Find(conds ...interface{}) db.Result {
+	return result.NewResult(c.Database().Builder(), c.Name(), conds)
+}
+
+// Database returns the database session that backs the collection.
+func (c *Collection) Database() Database {
+	return c.database
+}
diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go
new file mode 100644
index 00000000..b12869da
--- /dev/null
+++ b/internal/sqladapter/database.go
@@ -0,0 +1,11 @@
+package sqladapter
+
+import (
+	"upper.io/db"
+)
+
+type Database interface {
+	db.Database
+	TableExists(name string) error
+	TablePrimaryKey(name string) ([]string, error)
+}
diff --git a/postgresql/collection.go b/postgresql/collection.go
index d5a9fb37..14378bdf 100644
--- a/postgresql/collection.go
+++ b/postgresql/collection.go
@@ -28,22 +28,16 @@ import (
 
 	"upper.io/db"
 	"upper.io/db/builder"
+	"upper.io/db/internal/sqladapter"
 	"upper.io/db/util/sqlgen"
-	"upper.io/db/util/sqlutil/result"
 )
 
 type table struct {
-	*database
-	name string
+	*sqladapter.Collection
 }
 
 var _ = db.Collection(&table{})
 
-// Find creates a result set with the given conditions.
-func (t *table) Find(conds ...interface{}) db.Result {
-	return result.NewResult(t.database.Builder(), t.Name(), conds)
-}
-
 // Truncate deletes all rows from the table.
 func (t *table) Truncate() error {
 	stmt := sqlgen.Statement{
@@ -52,7 +46,7 @@ func (t *table) Truncate() error {
 		Extra: sqlgen.Extra("RESTART IDENTITY"),
 	}
 
-	if _, err := t.database.Exec(&stmt); err != nil {
+	if _, err := t.Database().Builder().Exec(&stmt); err != nil {
 		return err
 	}
 	return nil
@@ -66,37 +60,35 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 	}
 
 	var pKey []string
-	if pKey, err = t.database.getPrimaryKey(t.Name()); err != nil {
+	if pKey, err = t.Database().TablePrimaryKey(t.Name()); err != nil {
 		if err != sql.ErrNoRows {
-			// Can't tell primary key.
 			return nil, err
 		}
 	}
 
-	q := t.database.Builder().InsertInto(t.Name()).
+	q := t.Database().Builder().InsertInto(t.Name()).
 		Columns(columnNames...).
 		Values(columnValues...)
 
-	// No primary keys defined.
 	if len(pKey) == 0 {
-
+		// There is no primary key.
 		var res sql.Result
+
 		if res, err = q.Exec(); err != nil {
 			return nil, err
 		}
 
-		// Attempt to use LastInsertId() (probably won't work, but the exec()
-		// succeeded, so the error from LastInsertId() is ignored).
+		// Attempt to use LastInsertId() (probably won't work, but the Exec()
+		// succeeded, so we can safely ignore the error from LastInsertId()).
 		lastID, _ := res.LastInsertId()
 
 		return lastID, nil
 	}
 
-	// A primary key was found.
+	// Asking the database to return the primary key after insertion.
 	q.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`)))
 
 	var keyMap map[string]interface{}
-
 	if err = q.Iterator().One(&keyMap); err != nil {
 		return nil, err
 	}
@@ -133,19 +125,10 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 		return id.(int64), nil
 	}
 
-	// More than one key, no interface matched, let's return a map.
+	// This was a compound key and no interface matched it, let's return a map.
 	return keyMap, nil
 }
 
-// Exists returns true if the collection exists.
-func (t *table) Exists() bool {
-	if err := t.database.tableExists(t.Name()); err != nil {
-		return false
-	}
-	return true
-}
-
-// Name returns the name of the table or tables that form the collection.
-func (t *table) Name() string {
-	return t.name
+func newTable(d *database, name string) *table {
+	return &table{sqladapter.NewCollection(d, name)}
 }
diff --git a/postgresql/database.go b/postgresql/database.go
index 010c3507..2a786349 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -211,7 +211,7 @@ func (d *database) C(name string) db.Collection {
 	return c
 }
 
-// Collection returns a table by name.
+// Collection returns the table that matches the given name.
 func (d *database) Collection(name string) (db.Collection, error) {
 	if d.tx != nil {
 		if d.tx.Done() {
@@ -219,11 +219,11 @@ func (d *database) Collection(name string) (db.Collection, error) {
 		}
 	}
 
-	if err := d.tableExists(name); err != nil {
+	if err := d.TableExists(name); err != nil {
 		return nil, err
 	}
 
-	col := &table{database: d, name: name}
+	col := newTable(d, name)
 
 	d.collectionsMu.Lock()
 	d.collections[name] = col
@@ -424,19 +424,13 @@ func (d *database) populateSchema() (err error) {
 	return err
 }
 
-func (d *database) tableExists(names ...string) error {
-	var row map[string]string
-
-	for _, tableName := range names {
-
-		if d.schema.HasTable(tableName) {
-			// We already know this table exists.
-			continue
-		}
-
+func (d *database) TableExists(name string) error {
+	if !d.schema.HasTable(name) {
 		q := d.Builder().Select("table_name").
 			From("information_schema.tables").
-			Where("table_catalog = ? AND table_name = ?", d.schema.Name, tableName)
+			Where("table_catalog = ? AND table_name = ?", d.schema.Name, name)
+
+		var row map[string]string
 
 		if err := q.Iterator().One(&row); err != nil {
 			return db.ErrCollectionDoesNotExist
@@ -447,10 +441,6 @@ func (d *database) tableExists(names ...string) error {
 }
 
 func (d *database) TableColumns(tableName string) ([]string, error) {
-	return d.tableColumns(tableName)
-}
-
-func (d *database) tableColumns(tableName string) ([]string, error) {
 
 	tableSchema := d.schema.Table(tableName)
 
@@ -477,7 +467,7 @@ func (d *database) tableColumns(tableName string) ([]string, error) {
 	return d.schema.TableInfo[tableName].Columns, nil
 }
 
-func (d *database) getPrimaryKey(tableName string) ([]string, error) {
+func (d *database) TablePrimaryKey(tableName string) ([]string, error) {
 	tableSchema := d.schema.Table(tableName)
 
 	if len(tableSchema.PrimaryKey) != 0 {
-- 
GitLab