From 35ecf39581e2863321eed1330d12a9fe16a005d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Thu, 28 May 2015 06:00:17 -0500
Subject: [PATCH] Moving more shared logic to sqlutil.

---
 mysql/collection.go         | 65 +++++++++----------------------------
 mysql/database.go           |  6 ++--
 mysql/database_test.go      |  6 ++--
 postgresql/collection.go    | 22 +++----------
 postgresql/database.go      | 28 ++++------------
 postgresql/database_test.go |  2 +-
 postgresql/template.go      |  2 +-
 util/sqlutil/debug.go       | 17 ++++++++++
 util/sqlutil/main.go        | 51 +++++++++++++++++++++++++++++
 9 files changed, 103 insertions(+), 96 deletions(-)

diff --git a/mysql/collection.go b/mysql/collection.go
index 95a11462..6b53facb 100644
--- a/mysql/collection.go
+++ b/mysql/collection.go
@@ -34,22 +34,10 @@ import (
 type table struct {
 	sqlutil.T
 	*database
-	names []string
 }
 
 var _ = db.Collection(&table{})
 
-// tableN returns the nth name provided to the table.
-func (t *table) tableN(i int) string {
-	if len(t.names) > i {
-		chunks := strings.SplitN(t.names[i], " ", 2)
-		if len(chunks) > 0 {
-			return chunks[0]
-		}
-	}
-	return ""
-}
-
 // Find creates a result set with the given conditions.
 func (t *table) Find(terms ...interface{}) db.Result {
 	where, arguments := sqlutil.ToWhereWithArguments(terms)
@@ -60,7 +48,7 @@ func (t *table) Find(terms ...interface{}) db.Result {
 func (t *table) Truncate() error {
 	_, err := t.database.Exec(sqlgen.Statement{
 		Type:  sqlgen.Truncate,
-		Table: sqlgen.TableWithName(t.tableN(0)),
+		Table: sqlgen.TableWithName(t.MainTableName()),
 	})
 
 	if err != nil {
@@ -73,41 +61,19 @@ func (t *table) Truncate() error {
 func (t *table) Append(item interface{}) (interface{}, error) {
 	var pKey []string
 
-	cols, vals, err := t.FieldValues(item)
+	columnNames, columnValues, err := t.FieldValues(item)
 
 	if err != nil {
 		return nil, err
 	}
 
-	columns := new(sqlgen.Columns)
+	sqlgenCols, sqlgenVals, sqlgenArgs, err := t.ColumnsValuesAndArguments(columnNames, columnValues)
 
-	columns.Columns = make([]sqlgen.Fragment, 0, len(cols))
-	for i := range cols {
-		columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(cols[i]))
-	}
-
-	values := new(sqlgen.Values)
-	var arguments []interface{}
-
-	arguments = make([]interface{}, 0, len(vals))
-	values.Values = make([]sqlgen.Fragment, 0, len(vals))
-
-	for i := range vals {
-		switch v := vals[i].(type) {
-		case *sqlgen.Value:
-			// Adding value.
-			values.Values = append(values.Values, v)
-		case sqlgen.Value:
-			// Adding value.
-			values.Values = append(values.Values, &v)
-		default:
-			// Adding both value and placeholder.
-			values.Values = append(values.Values, sqlPlaceholder)
-			arguments = append(arguments, v)
-		}
+	if err != nil {
+		return nil, err
 	}
 
-	if pKey, err = t.database.getPrimaryKey(t.tableN(0)); err != nil {
+	if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil {
 		if err != sql.ErrNoRows {
 			// Can't tell primary key.
 			return nil, err
@@ -116,13 +82,13 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 
 	stmt := sqlgen.Statement{
 		Type:    sqlgen.Insert,
-		Table:   sqlgen.TableWithName(t.tableN(0)),
-		Columns: columns,
-		Values:  values,
+		Table:   sqlgen.TableWithName(t.MainTableName()),
+		Columns: sqlgenCols,
+		Values:  sqlgenVals,
 	}
 
 	var res sql.Result
-	if res, err = t.database.Exec(stmt, arguments...); err != nil {
+	if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil {
 		return nil, err
 	}
 
@@ -149,10 +115,10 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 	// were given for constructing the composite key.
 	keyMap := make(map[string]interface{})
 
-	for i := range cols {
+	for i := range columnNames {
 		for j := 0; j < len(pKey); j++ {
-			if pKey[j] == cols[i] {
-				keyMap[pKey[j]] = vals[i]
+			if pKey[j] == columnNames[i] {
+				keyMap[pKey[j]] = columnValues[i]
 			}
 		}
 	}
@@ -177,12 +143,13 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 
 // Returns true if the collection exists.
 func (t *table) Exists() bool {
-	if err := t.database.tableExists(t.names...); err != nil {
+	if err := t.database.tableExists(t.Tables...); err != nil {
 		return false
 	}
 	return true
 }
 
+// Name returns the name of the table or tables that form the collection.
 func (t *table) Name() string {
-	return strings.Join(t.names, `, `)
+	return strings.Join(t.Tables, `, `)
 }
diff --git a/mysql/database.go b/mysql/database.go
index 11e5e5de..389cb156 100644
--- a/mysql/database.go
+++ b/mysql/database.go
@@ -163,10 +163,8 @@ func (d *database) Collection(names ...string) (db.Collection, error) {
 		}
 	}
 
-	col := &table{
-		database: d,
-		names:    names,
-	}
+	col := &table{database: d}
+	col.Tables = names
 
 	for _, name := range names {
 		chunks := strings.SplitN(name, ` `, 2)
diff --git a/mysql/database_test.go b/mysql/database_test.go
index e0e4d3ae..624dcc20 100644
--- a/mysql/database_test.go
+++ b/mysql/database_test.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam
+// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -1470,7 +1470,7 @@ func BenchmarkAppendRawSQL(b *testing.B) {
 
 	defer sess.Close()
 
-	driver := sess.Driver().(*sql.DB)
+	driver := sess.Driver().(*sqlx.DB)
 
 	if _, err = driver.Exec("TRUNCATE TABLE `artist`"); err != nil {
 		b.Fatal(err)
@@ -1524,7 +1524,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) {
 
 	defer sess.Close()
 
-	driver := sess.Driver().(*sql.DB)
+	driver := sess.Driver().(*sqlx.DB)
 
 	if tx, err = driver.Begin(); err != nil {
 		b.Fatal(err)
diff --git a/postgresql/collection.go b/postgresql/collection.go
index abdaeb55..19ecda73 100644
--- a/postgresql/collection.go
+++ b/postgresql/collection.go
@@ -37,22 +37,10 @@ type table struct {
 	sqlutil.T
 	*database
 	primaryKey string
-	names      []string
 }
 
 var _ = db.Collection(&table{})
 
-// tableN returns the nth name provided to the table.
-func (t *table) tableN(i int) string {
-	if len(t.names) > i {
-		chunks := strings.SplitN(t.names[i], " ", 2)
-		if len(chunks) > 0 {
-			return chunks[0]
-		}
-	}
-	return ""
-}
-
 // Find creates a result set with the given conditions.
 func (t *table) Find(terms ...interface{}) db.Result {
 	where, arguments := sqlutil.ToWhereWithArguments(terms)
@@ -63,7 +51,7 @@ func (t *table) Find(terms ...interface{}) db.Result {
 func (t *table) Truncate() error {
 	_, err := t.database.Exec(sqlgen.Statement{
 		Type:  sqlgen.Truncate,
-		Table: sqlgen.TableWithName(t.tableN(0)),
+		Table: sqlgen.TableWithName(t.MainTableName()),
 	})
 
 	if err != nil {
@@ -112,7 +100,7 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 
 	var pKey []string
 
-	if pKey, err = t.database.getPrimaryKey(t.tableN(0)); err != nil {
+	if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil {
 		if err != sql.ErrNoRows {
 			// Can't tell primary key.
 			return nil, err
@@ -121,7 +109,7 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 
 	stmt := sqlgen.Statement{
 		Type:    sqlgen.Insert,
-		Table:   sqlgen.TableWithName(t.tableN(0)),
+		Table:   sqlgen.TableWithName(t.MainTableName()),
 		Columns: columns,
 		Values:  values,
 	}
@@ -194,7 +182,7 @@ func (t *table) Append(item interface{}) (interface{}, error) {
 
 // Exists returns true if the collection exists.
 func (t *table) Exists() bool {
-	if err := t.database.tableExists(t.names...); err != nil {
+	if err := t.database.tableExists(t.Tables...); err != nil {
 		return false
 	}
 	return true
@@ -202,5 +190,5 @@ func (t *table) Exists() bool {
 
 // Name returns the name of the table or tables that form the collection.
 func (t *table) Name() string {
-	return strings.Join(t.names, `, `)
+	return strings.Join(t.Tables, `, `)
 }
diff --git a/postgresql/database.go b/postgresql/database.go
index 18a79d41..0ece2ade 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -24,13 +24,12 @@ package postgresql
 import (
 	"database/sql"
 	"fmt"
-	"os"
 	"strconv"
 	"strings"
 	"time"
 
 	"github.com/jmoiron/sqlx"
-	_ "github.com/lib/pq" // Go PostgreSQL driver.
+	_ "github.com/lib/pq" // PostgreSQL driver.
 	"upper.io/db"
 	"upper.io/db/util/schema"
 	"upper.io/db/util/sqlgen"
@@ -64,26 +63,12 @@ type columnSchemaT struct {
 	DataType string `db:"data_type"`
 }
 
-func debugEnabled() bool {
-	if os.Getenv(db.EnvEnableDebug) != "" {
-		return true
-	}
-	return false
-}
-
-func debugLog(query string, args []interface{}, err error, start int64, end int64) {
-	if debugEnabled() == true {
-		d := sqlutil.Debug{query, args, err, start, end}
-		d.Print()
-	}
-}
-
 // Driver returns the underlying *sqlx.DB instance.
 func (d *database) Driver() interface{} {
 	return d.session
 }
 
-// Open attempts to connect to the PostgreSQL server using already stored settings.
+// Open attempts to connect to the database server using already stored settings.
 func (d *database) Open() error {
 	var err error
 
@@ -164,9 +149,10 @@ func (d *database) Collection(names ...string) (db.Collection, error) {
 
 	col := &table{
 		database: d,
-		names:    names,
 	}
 
+	col.Tables = names
+
 	for _, name := range names {
 		chunks := strings.SplitN(name, ` `, 2)
 
@@ -311,7 +297,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result,
 
 	defer func() {
 		end = time.Now().UnixNano()
-		debugLog(query, args, err, start, end)
+		sqlutil.Log(query, args, err, start, end)
 	}()
 
 	if d.session == nil {
@@ -345,7 +331,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows
 
 	defer func() {
 		end = time.Now().UnixNano()
-		debugLog(query, args, err, start, end)
+		sqlutil.Log(query, args, err, start, end)
 	}()
 
 	if d.session == nil {
@@ -379,7 +365,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R
 
 	defer func() {
 		end = time.Now().UnixNano()
-		debugLog(query, args, err, start, end)
+		sqlutil.Log(query, args, err, start, end)
 	}()
 
 	if d.session == nil {
diff --git a/postgresql/database_test.go b/postgresql/database_test.go
index 0e7e9350..090779c4 100644
--- a/postgresql/database_test.go
+++ b/postgresql/database_test.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam
+// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
diff --git a/postgresql/template.go b/postgresql/template.go
index 7c8f13a4..6df9d1c4 100644
--- a/postgresql/template.go
+++ b/postgresql/template.go
@@ -1,4 +1,4 @@
-// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam
+// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
diff --git a/util/sqlutil/debug.go b/util/sqlutil/debug.go
index 08d8ebe9..f82a9857 100644
--- a/util/sqlutil/debug.go
+++ b/util/sqlutil/debug.go
@@ -24,7 +24,10 @@ package sqlutil
 import (
 	"fmt"
 	"log"
+	"os"
 	"strings"
+
+	"upper.io/db"
 )
 
 // Debug is used for printing SQL queries and arguments.
@@ -59,3 +62,17 @@ func (d *Debug) Print() {
 
 	log.Printf("\n\t%s\n\n", strings.Join(s, "\n\t"))
 }
+
+func IsDebugEnabled() bool {
+	if os.Getenv(db.EnvEnableDebug) != "" {
+		return true
+	}
+	return false
+}
+
+func Log(query string, args []interface{}, err error, start int64, end int64) {
+	if IsDebugEnabled() {
+		d := Debug{query, args, err, start, end}
+		d.Print()
+	}
+}
diff --git a/util/sqlutil/main.go b/util/sqlutil/main.go
index 77cc0637..11e5e266 100644
--- a/util/sqlutil/main.go
+++ b/util/sqlutil/main.go
@@ -33,6 +33,7 @@ import (
 
 	"upper.io/db"
 	"upper.io/db/util"
+	"upper.io/db/util/sqlgen"
 )
 
 var (
@@ -50,6 +51,7 @@ var (
 // using FieldValues()
 type T struct {
 	Columns []string
+	Tables  []string
 }
 
 func (t *T) columnLike(s string) string {
@@ -210,3 +212,52 @@ func NewMapper() *reflectx.Mapper {
 
 	return reflectx.NewMapperTagFunc("db", mapFunc, tagFunc)
 }
+
+// MainTableName returns the name of the first table.
+func (t *T) MainTableName() string {
+	return t.NthTableName(0)
+}
+
+// NthTableName returns the table name at index i.
+func (t *T) NthTableName(i int) string {
+	if len(t.Tables) > i {
+		chunks := strings.SplitN(t.Tables[i], " ", 2)
+		if len(chunks) > 0 {
+			return chunks[0]
+		}
+	}
+	return ""
+}
+
+func (t *T) ColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) {
+	var arguments []interface{}
+
+	columns := new(sqlgen.Columns)
+
+	columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames))
+	for i := range columnNames {
+		columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i]))
+	}
+
+	values := new(sqlgen.Values)
+
+	arguments = make([]interface{}, 0, len(columnValues))
+	values.Values = make([]sqlgen.Fragment, 0, len(columnValues))
+
+	for i := range columnValues {
+		switch v := columnValues[i].(type) {
+		case *sqlgen.Value:
+			// Adding value.
+			values.Values = append(values.Values, v)
+		case sqlgen.Value:
+			// Adding value.
+			values.Values = append(values.Values, &v)
+		default:
+			// Adding both value and placeholder.
+			values.Values = append(values.Values, sqlPlaceholder)
+			arguments = append(arguments, v)
+		}
+	}
+
+	return columns, values, arguments, nil
+}
-- 
GitLab