good morning!!!!

Skip to content
Snippets Groups Projects
Commit 4efce5d9 authored by José Carlos Nieto's avatar José Carlos Nieto
Browse files

PostgreSQL: Adding support for upper.io/db/util/schema.

parent 845ef6df
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,6 @@ import (
"database/sql"
"fmt"
"os"
"regexp"
"strconv"
"strings"
"time"
......@@ -33,6 +32,7 @@ import (
_ "github.com/xiam/gopostgresql"
"upper.io/cache"
"upper.io/db"
"upper.io/db/util/schema"
"upper.io/db/util/sqlgen"
"upper.io/db/util/sqlutil"
)
......@@ -50,19 +50,17 @@ var (
var template *sqlgen.Template
var (
columnPattern = regexp.MustCompile(`^([a-z]+)\(?([0-9,]+)?\)?\s?([a-z]*)?`)
sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}}
)
type source struct {
config db.Settings
session *sql.DB
collections map[string]db.Collection
tx *sql.Tx
primaryKeys map[string]string
config db.Settings
session *sql.DB
tx *sql.Tx
schema *schema.DatabaseSchema
}
type columnSchema_t struct {
type columnSchemaT struct {
Name string `db:"column_name"`
}
......@@ -117,6 +115,28 @@ func init() {
db.Register(Adapter, &source{})
}
func (self *source) populateSchema() (err error) {
var collections []string
self.schema = schema.NewDatabaseSchema()
self.schema.Name = self.config.Database
// The Collections() call will populate schema if its nil.
if collections, err = self.Collections(); err != nil {
return err
}
for i := range collections {
// Populate each collection.
if _, err = self.Collection(collections[i]); err != nil {
return err
}
}
return err
}
func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) {
var query string
var res sql.Result
......@@ -265,7 +285,7 @@ func (self *source) Transaction() (db.Tx, error) {
// Stores database settings.
func (self *source) Setup(config db.Settings) error {
self.config = config
self.collections = make(map[string]db.Collection)
//self.collections = make(map[string]db.Collection)
return self.Open()
}
......@@ -310,8 +330,11 @@ func (self *source) Open() error {
}
conn += fmt.Sprintf(`dbname=%s sslmode=%s`, self.config.Database, SSLMode)
self.session, err = sql.Open(`postgres`, conn)
if err != nil {
if self.session, err = sql.Open(`postgres`, conn); err != nil {
return err
}
if err = self.populateSchema(); err != nil {
return err
}
......@@ -343,40 +366,74 @@ func (self *source) Drop() error {
return err
}
// Returns a list of all tables within the currently active database.
func (self *source) Collections() ([]string, error) {
var collections []string
var collection string
// Collections() Returns a list of non-system tables/collections contained
// within the currently active database.
func (self *source) Collections() (collections []string, err error) {
var tablesInSchema int = len(self.schema.Tables)
// Is schema already populated?
if tablesInSchema > 0 {
// Pulling table names from schema.
return self.schema.Tables, nil
}
// Schema is empty.
rows, err := self.doQuery(sqlgen.Statement{
// Querying table names.
stmt := sqlgen.Statement{
Type: sqlgen.SqlSelect,
Columns: sqlgen.Columns{
{"table_name"},
{`table_name`},
},
Table: sqlgen.Table{
`information_schema.tables`,
},
Table: sqlgen.Table{"information_schema.tables"},
Where: sqlgen.Where{
sqlgen.ColumnValue{sqlgen.Column{"table_schema"}, "=", sqlgen.Value{"public"}},
sqlgen.ColumnValue{
sqlgen.Column{`table_schema`},
`=`,
sqlgen.Value{`public`},
},
},
})
}
if err != nil {
// Executing statement.
var rows *sql.Rows
if rows, err = self.doQuery(stmt); err != nil {
return nil, err
}
defer rows.Close()
collections = []string{}
var name string
for rows.Next() {
rows.Scan(&collection)
collections = append(collections, collection)
// Getting table name.
if err = rows.Scan(&name); err != nil {
return nil, err
}
// Adding table entry to schema.
self.schema.AddTable(name)
// Adding table to collections array.
collections = append(collections, name)
}
return collections, nil
}
func (self *source) tableExists(names ...string) error {
for _, name := range names {
var stmt sqlgen.Statement
var err error
var rows *sql.Rows
rows, err := self.doQuery(sqlgen.Statement{
for i := range names {
stmt = sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`information_schema.tables`},
Columns: sqlgen.Columns{
......@@ -386,9 +443,9 @@ func (self *source) tableExists(names ...string) error {
sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
},
}, self.config.Database, name)
}
if err != nil {
if rows, err = self.doQuery(stmt, self.config.Database, names[i]); err != nil {
return db.ErrCollectionDoesNotExist
}
......@@ -402,9 +459,62 @@ func (self *source) tableExists(names ...string) error {
return nil
}
func (self *source) tableColumns(tableName string) ([]string, error) {
// Making sure this table is allocated.
tableSchema := self.schema.Table(tableName)
if len(tableSchema.Columns) > 0 {
return tableSchema.Columns, nil
}
stmt := sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{
`information_schema.columns`,
},
Columns: sqlgen.Columns{
{`column_name`},
{`data_type`},
},
Where: sqlgen.Where{
sqlgen.ColumnValue{
sqlgen.Column{`table_catalog`},
`=`,
sqlPlaceholder,
},
sqlgen.ColumnValue{
sqlgen.Column{`table_name`},
`=`,
sqlPlaceholder,
},
},
}
var rows *sql.Rows
var err error
if rows, err = self.doQuery(stmt, self.config.Database, tableName); err != nil {
return nil, err
}
tableFields := []columnSchemaT{}
if err = sqlutil.FetchRows(rows, &tableFields); err != nil {
return nil, err
}
self.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields))
for i := range tableFields {
self.schema.TableInfo[tableName].Columns = append(self.schema.TableInfo[tableName].Columns, tableFields[i].Name)
}
return self.schema.TableInfo[tableName].Columns, nil
}
// Returns a collection instance by name.
func (self *source) Collection(names ...string) (db.Collection, error) {
var rows *sql.Rows
var err error
if len(names) == 0 {
......@@ -416,46 +526,21 @@ func (self *source) Collection(names ...string) (db.Collection, error) {
names: names,
}
columns_t := []columnSchema_t{}
for _, name := range names {
chunks := strings.SplitN(name, " ", 2)
if len(chunks) > 0 {
name = chunks[0]
if err := self.tableExists(name); err != nil {
return nil, err
}
// Getting columns
rows, err = self.doQuery(sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`information_schema.columns`},
Columns: sqlgen.Columns{
{`column_name`},
{`data_type`},
},
Where: sqlgen.Where{
sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
},
}, self.config.Database, name)
if err != nil {
return nil, err
}
if err = sqlutil.FetchRows(rows, &columns_t); err != nil {
return nil, err
}
col.Columns = make([]string, 0, len(columns_t))
for _, column := range columns_t {
col.Columns = append(col.Columns, column.Name)
}
chunks := strings.SplitN(name, ` `, 2)
if len(chunks) == 0 {
return nil, db.ErrMissingCollectionName
}
tableName := chunks[0]
if err := self.tableExists(tableName); err != nil {
return nil, err
}
if col.Columns, err = self.tableColumns(tableName); err != nil {
return nil, err
}
}
......@@ -463,21 +548,15 @@ func (self *source) Collection(names ...string) (db.Collection, error) {
}
func (self *source) getPrimaryKey(tableName string) (string, error) {
var row *sql.Row
var err error
var pKey string
if self.primaryKeys == nil {
self.primaryKeys = make(map[string]string)
}
tableSchema := self.schema.Table(tableName)
if pKey, ok := self.primaryKeys[tableName]; ok {
// Retrieving cached key name.
return pKey, nil
if tableSchema.PrimaryKey != "" {
return tableSchema.PrimaryKey, nil
}
// Getting primary key. See https://github.com/upper/db/issues/24.
row, err = self.doQueryRow(sqlgen.Statement{
stmt := sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`},
Columns: sqlgen.Columns{
......@@ -491,20 +570,18 @@ func (self *source) getPrimaryKey(tableName string) (string, error) {
sqlgen.Raw{`indisprimary`},
},
Limit: 1,
})
}
if err != nil {
var row *sql.Row
var err error
if row, err = self.doQueryRow(stmt); err != nil {
return "", err
}
if err = row.Scan(&pKey); err != nil {
if err = row.Scan(&tableSchema.PrimaryKey); err != nil {
return "", err
}
// Caching key name.
// TODO: There is currently no policy for cache life and no cache-cleaning
// methods are provided.
self.primaryKeys[tableName] = pKey
return pKey, nil
return tableSchema.PrimaryKey, nil
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment