From 4efce5d94d14fbedba8e432601e905c3ef4530ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 16 Sep 2014 07:49:13 -0500
Subject: [PATCH] PostgreSQL: Adding support for upper.io/db/util/schema.

---
 postgresql/database.go | 249 +++++++++++++++++++++++++++--------------
 1 file changed, 163 insertions(+), 86 deletions(-)

diff --git a/postgresql/database.go b/postgresql/database.go
index cbff654a..f5d5a8ab 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -25,7 +25,6 @@ import (
 	"database/sql"
 	"fmt"
 	"os"
-	"regexp"
 	"strconv"
 	"strings"
 	"time"
@@ -33,6 +32,7 @@ import (
 	_ "github.com/xiam/gopostgresql"
 	"upper.io/cache"
 	"upper.io/db"
+	"upper.io/db/util/schema"
 	"upper.io/db/util/sqlgen"
 	"upper.io/db/util/sqlutil"
 )
@@ -50,19 +50,17 @@ var (
 var template *sqlgen.Template
 
 var (
-	columnPattern  = regexp.MustCompile(`^([a-z]+)\(?([0-9,]+)?\)?\s?([a-z]*)?`)
 	sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}}
 )
 
 type source struct {
-	config      db.Settings
-	session     *sql.DB
-	collections map[string]db.Collection
-	tx          *sql.Tx
-	primaryKeys map[string]string
+	config  db.Settings
+	session *sql.DB
+	tx      *sql.Tx
+	schema  *schema.DatabaseSchema
 }
 
-type columnSchema_t struct {
+type columnSchemaT struct {
 	Name string `db:"column_name"`
 }
 
@@ -117,6 +115,28 @@ func init() {
 	db.Register(Adapter, &source{})
 }
 
+func (self *source) populateSchema() (err error) {
+	var collections []string
+
+	self.schema = schema.NewDatabaseSchema()
+
+	self.schema.Name = self.config.Database
+
+	// The Collections() call will populate schema if its nil.
+	if collections, err = self.Collections(); err != nil {
+		return err
+	}
+
+	for i := range collections {
+		// Populate each collection.
+		if _, err = self.Collection(collections[i]); err != nil {
+			return err
+		}
+	}
+
+	return err
+}
+
 func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) {
 	var query string
 	var res sql.Result
@@ -265,7 +285,7 @@ func (self *source) Transaction() (db.Tx, error) {
 // Stores database settings.
 func (self *source) Setup(config db.Settings) error {
 	self.config = config
-	self.collections = make(map[string]db.Collection)
+	//self.collections = make(map[string]db.Collection)
 	return self.Open()
 }
 
@@ -310,8 +330,11 @@ func (self *source) Open() error {
 	}
 	conn += fmt.Sprintf(`dbname=%s sslmode=%s`, self.config.Database, SSLMode)
 
-	self.session, err = sql.Open(`postgres`, conn)
-	if err != nil {
+	if self.session, err = sql.Open(`postgres`, conn); err != nil {
+		return err
+	}
+
+	if err = self.populateSchema(); err != nil {
 		return err
 	}
 
@@ -343,40 +366,74 @@ func (self *source) Drop() error {
 	return err
 }
 
-// Returns a list of all tables within the currently active database.
-func (self *source) Collections() ([]string, error) {
-	var collections []string
-	var collection string
+// Collections() Returns a list of non-system tables/collections contained
+// within the currently active database.
+func (self *source) Collections() (collections []string, err error) {
+
+	var tablesInSchema int = len(self.schema.Tables)
+
+	// Is schema already populated?
+	if tablesInSchema > 0 {
+		// Pulling table names from schema.
+		return self.schema.Tables, nil
+	}
+
+	// Schema is empty.
 
-	rows, err := self.doQuery(sqlgen.Statement{
+	// Querying table names.
+	stmt := sqlgen.Statement{
 		Type: sqlgen.SqlSelect,
 		Columns: sqlgen.Columns{
-			{"table_name"},
+			{`table_name`},
+		},
+		Table: sqlgen.Table{
+			`information_schema.tables`,
 		},
-		Table: sqlgen.Table{"information_schema.tables"},
 		Where: sqlgen.Where{
-			sqlgen.ColumnValue{sqlgen.Column{"table_schema"}, "=", sqlgen.Value{"public"}},
+			sqlgen.ColumnValue{
+				sqlgen.Column{`table_schema`},
+				`=`,
+				sqlgen.Value{`public`},
+			},
 		},
-	})
+	}
 
-	if err != nil {
+	// Executing statement.
+	var rows *sql.Rows
+	if rows, err = self.doQuery(stmt); err != nil {
 		return nil, err
 	}
 
 	defer rows.Close()
 
+	collections = []string{}
+
+	var name string
+
 	for rows.Next() {
-		rows.Scan(&collection)
-		collections = append(collections, collection)
+		// Getting table name.
+		if err = rows.Scan(&name); err != nil {
+			return nil, err
+		}
+
+		// Adding table entry to schema.
+		self.schema.AddTable(name)
+
+		// Adding table to collections array.
+		collections = append(collections, name)
 	}
 
 	return collections, nil
 }
 
 func (self *source) tableExists(names ...string) error {
-	for _, name := range names {
+	var stmt sqlgen.Statement
+	var err error
+	var rows *sql.Rows
 
-		rows, err := self.doQuery(sqlgen.Statement{
+	for i := range names {
+
+		stmt = sqlgen.Statement{
 			Type:  sqlgen.SqlSelect,
 			Table: sqlgen.Table{`information_schema.tables`},
 			Columns: sqlgen.Columns{
@@ -386,9 +443,9 @@ func (self *source) tableExists(names ...string) error {
 				sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
 				sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
 			},
-		}, self.config.Database, name)
+		}
 
-		if err != nil {
+		if rows, err = self.doQuery(stmt, self.config.Database, names[i]); err != nil {
 			return db.ErrCollectionDoesNotExist
 		}
 
@@ -402,9 +459,62 @@ func (self *source) tableExists(names ...string) error {
 	return nil
 }
 
+func (self *source) tableColumns(tableName string) ([]string, error) {
+
+	// Making sure this table is allocated.
+	tableSchema := self.schema.Table(tableName)
+
+	if len(tableSchema.Columns) > 0 {
+		return tableSchema.Columns, nil
+	}
+
+	stmt := sqlgen.Statement{
+		Type: sqlgen.SqlSelect,
+		Table: sqlgen.Table{
+			`information_schema.columns`,
+		},
+		Columns: sqlgen.Columns{
+			{`column_name`},
+			{`data_type`},
+		},
+		Where: sqlgen.Where{
+			sqlgen.ColumnValue{
+				sqlgen.Column{`table_catalog`},
+				`=`,
+				sqlPlaceholder,
+			},
+			sqlgen.ColumnValue{
+				sqlgen.Column{`table_name`},
+				`=`,
+				sqlPlaceholder,
+			},
+		},
+	}
+
+	var rows *sql.Rows
+	var err error
+
+	if rows, err = self.doQuery(stmt, self.config.Database, tableName); err != nil {
+		return nil, err
+	}
+
+	tableFields := []columnSchemaT{}
+
+	if err = sqlutil.FetchRows(rows, &tableFields); err != nil {
+		return nil, err
+	}
+
+	self.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields))
+
+	for i := range tableFields {
+		self.schema.TableInfo[tableName].Columns = append(self.schema.TableInfo[tableName].Columns, tableFields[i].Name)
+	}
+
+	return self.schema.TableInfo[tableName].Columns, nil
+}
+
 // Returns a collection instance by name.
 func (self *source) Collection(names ...string) (db.Collection, error) {
-	var rows *sql.Rows
 	var err error
 
 	if len(names) == 0 {
@@ -416,46 +526,21 @@ func (self *source) Collection(names ...string) (db.Collection, error) {
 		names:  names,
 	}
 
-	columns_t := []columnSchema_t{}
-
 	for _, name := range names {
-		chunks := strings.SplitN(name, " ", 2)
-
-		if len(chunks) > 0 {
-
-			name = chunks[0]
-
-			if err := self.tableExists(name); err != nil {
-				return nil, err
-			}
-
-			// Getting columns
-			rows, err = self.doQuery(sqlgen.Statement{
-				Type:  sqlgen.SqlSelect,
-				Table: sqlgen.Table{`information_schema.columns`},
-				Columns: sqlgen.Columns{
-					{`column_name`},
-					{`data_type`},
-				},
-				Where: sqlgen.Where{
-					sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
-					sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
-				},
-			}, self.config.Database, name)
-
-			if err != nil {
-				return nil, err
-			}
-
-			if err = sqlutil.FetchRows(rows, &columns_t); err != nil {
-				return nil, err
-			}
-
-			col.Columns = make([]string, 0, len(columns_t))
-
-			for _, column := range columns_t {
-				col.Columns = append(col.Columns, column.Name)
-			}
+		chunks := strings.SplitN(name, ` `, 2)
+
+		if len(chunks) == 0 {
+			return nil, db.ErrMissingCollectionName
+		}
+
+		tableName := chunks[0]
+
+		if err := self.tableExists(tableName); err != nil {
+			return nil, err
+		}
+
+		if col.Columns, err = self.tableColumns(tableName); err != nil {
+			return nil, err
 		}
 	}
 
@@ -463,21 +548,15 @@ func (self *source) Collection(names ...string) (db.Collection, error) {
 }
 
 func (self *source) getPrimaryKey(tableName string) (string, error) {
-	var row *sql.Row
-	var err error
-	var pKey string
 
-	if self.primaryKeys == nil {
-		self.primaryKeys = make(map[string]string)
-	}
+	tableSchema := self.schema.Table(tableName)
 
-	if pKey, ok := self.primaryKeys[tableName]; ok {
-		// Retrieving cached key name.
-		return pKey, nil
+	if tableSchema.PrimaryKey != "" {
+		return tableSchema.PrimaryKey, nil
 	}
 
 	// Getting primary key. See https://github.com/upper/db/issues/24.
-	row, err = self.doQueryRow(sqlgen.Statement{
+	stmt := sqlgen.Statement{
 		Type:  sqlgen.SqlSelect,
 		Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`},
 		Columns: sqlgen.Columns{
@@ -491,20 +570,18 @@ func (self *source) getPrimaryKey(tableName string) (string, error) {
 			sqlgen.Raw{`indisprimary`},
 		},
 		Limit: 1,
-	})
+	}
 
-	if err != nil {
+	var row *sql.Row
+	var err error
+
+	if row, err = self.doQueryRow(stmt); err != nil {
 		return "", err
 	}
 
-	if err = row.Scan(&pKey); err != nil {
+	if err = row.Scan(&tableSchema.PrimaryKey); err != nil {
 		return "", err
 	}
 
-	// Caching key name.
-	// TODO: There is currently no policy for cache life and no cache-cleaning
-	// methods are provided.
-	self.primaryKeys[tableName] = pKey
-
-	return pKey, nil
+	return tableSchema.PrimaryKey, nil
 }
-- 
GitLab