good morning!!!!

Skip to content
Snippets Groups Projects
Commit 4efce5d9 authored by José Carlos Nieto's avatar José Carlos Nieto
Browse files

PostgreSQL: Adding support for upper.io/db/util/schema.

parent 845ef6df
Branches
Tags
No related merge requests found
...@@ -25,7 +25,6 @@ import ( ...@@ -25,7 +25,6 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
...@@ -33,6 +32,7 @@ import ( ...@@ -33,6 +32,7 @@ import (
_ "github.com/xiam/gopostgresql" _ "github.com/xiam/gopostgresql"
"upper.io/cache" "upper.io/cache"
"upper.io/db" "upper.io/db"
"upper.io/db/util/schema"
"upper.io/db/util/sqlgen" "upper.io/db/util/sqlgen"
"upper.io/db/util/sqlutil" "upper.io/db/util/sqlutil"
) )
...@@ -50,19 +50,17 @@ var ( ...@@ -50,19 +50,17 @@ var (
var template *sqlgen.Template var template *sqlgen.Template
var ( var (
columnPattern = regexp.MustCompile(`^([a-z]+)\(?([0-9,]+)?\)?\s?([a-z]*)?`)
sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}}
) )
type source struct { type source struct {
config db.Settings config db.Settings
session *sql.DB session *sql.DB
collections map[string]db.Collection
tx *sql.Tx tx *sql.Tx
primaryKeys map[string]string schema *schema.DatabaseSchema
} }
type columnSchema_t struct { type columnSchemaT struct {
Name string `db:"column_name"` Name string `db:"column_name"`
} }
...@@ -117,6 +115,28 @@ func init() { ...@@ -117,6 +115,28 @@ func init() {
db.Register(Adapter, &source{}) db.Register(Adapter, &source{})
} }
func (self *source) populateSchema() (err error) {
var collections []string
self.schema = schema.NewDatabaseSchema()
self.schema.Name = self.config.Database
// The Collections() call will populate schema if its nil.
if collections, err = self.Collections(); err != nil {
return err
}
for i := range collections {
// Populate each collection.
if _, err = self.Collection(collections[i]); err != nil {
return err
}
}
return err
}
func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) {
var query string var query string
var res sql.Result var res sql.Result
...@@ -265,7 +285,7 @@ func (self *source) Transaction() (db.Tx, error) { ...@@ -265,7 +285,7 @@ func (self *source) Transaction() (db.Tx, error) {
// Stores database settings. // Stores database settings.
func (self *source) Setup(config db.Settings) error { func (self *source) Setup(config db.Settings) error {
self.config = config self.config = config
self.collections = make(map[string]db.Collection) //self.collections = make(map[string]db.Collection)
return self.Open() return self.Open()
} }
...@@ -310,8 +330,11 @@ func (self *source) Open() error { ...@@ -310,8 +330,11 @@ func (self *source) Open() error {
} }
conn += fmt.Sprintf(`dbname=%s sslmode=%s`, self.config.Database, SSLMode) conn += fmt.Sprintf(`dbname=%s sslmode=%s`, self.config.Database, SSLMode)
self.session, err = sql.Open(`postgres`, conn) if self.session, err = sql.Open(`postgres`, conn); err != nil {
if err != nil { return err
}
if err = self.populateSchema(); err != nil {
return err return err
} }
...@@ -343,40 +366,74 @@ func (self *source) Drop() error { ...@@ -343,40 +366,74 @@ func (self *source) Drop() error {
return err return err
} }
// Returns a list of all tables within the currently active database. // Collections() Returns a list of non-system tables/collections contained
func (self *source) Collections() ([]string, error) { // within the currently active database.
var collections []string func (self *source) Collections() (collections []string, err error) {
var collection string
var tablesInSchema int = len(self.schema.Tables)
rows, err := self.doQuery(sqlgen.Statement{ // Is schema already populated?
if tablesInSchema > 0 {
// Pulling table names from schema.
return self.schema.Tables, nil
}
// Schema is empty.
// Querying table names.
stmt := sqlgen.Statement{
Type: sqlgen.SqlSelect, Type: sqlgen.SqlSelect,
Columns: sqlgen.Columns{ Columns: sqlgen.Columns{
{"table_name"}, {`table_name`},
},
Table: sqlgen.Table{
`information_schema.tables`,
}, },
Table: sqlgen.Table{"information_schema.tables"},
Where: sqlgen.Where{ Where: sqlgen.Where{
sqlgen.ColumnValue{sqlgen.Column{"table_schema"}, "=", sqlgen.Value{"public"}}, sqlgen.ColumnValue{
sqlgen.Column{`table_schema`},
`=`,
sqlgen.Value{`public`},
}, },
}) },
}
if err != nil { // Executing statement.
var rows *sql.Rows
if rows, err = self.doQuery(stmt); err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
collections = []string{}
var name string
for rows.Next() { for rows.Next() {
rows.Scan(&collection) // Getting table name.
collections = append(collections, collection) if err = rows.Scan(&name); err != nil {
return nil, err
}
// Adding table entry to schema.
self.schema.AddTable(name)
// Adding table to collections array.
collections = append(collections, name)
} }
return collections, nil return collections, nil
} }
func (self *source) tableExists(names ...string) error { func (self *source) tableExists(names ...string) error {
for _, name := range names { var stmt sqlgen.Statement
var err error
var rows *sql.Rows
rows, err := self.doQuery(sqlgen.Statement{ for i := range names {
stmt = sqlgen.Statement{
Type: sqlgen.SqlSelect, Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`information_schema.tables`}, Table: sqlgen.Table{`information_schema.tables`},
Columns: sqlgen.Columns{ Columns: sqlgen.Columns{
...@@ -386,9 +443,9 @@ func (self *source) tableExists(names ...string) error { ...@@ -386,9 +443,9 @@ func (self *source) tableExists(names ...string) error {
sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder}, sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
}, },
}, self.config.Database, name) }
if err != nil { if rows, err = self.doQuery(stmt, self.config.Database, names[i]); err != nil {
return db.ErrCollectionDoesNotExist return db.ErrCollectionDoesNotExist
} }
...@@ -402,9 +459,62 @@ func (self *source) tableExists(names ...string) error { ...@@ -402,9 +459,62 @@ func (self *source) tableExists(names ...string) error {
return nil return nil
} }
func (self *source) tableColumns(tableName string) ([]string, error) {
// Making sure this table is allocated.
tableSchema := self.schema.Table(tableName)
if len(tableSchema.Columns) > 0 {
return tableSchema.Columns, nil
}
stmt := sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{
`information_schema.columns`,
},
Columns: sqlgen.Columns{
{`column_name`},
{`data_type`},
},
Where: sqlgen.Where{
sqlgen.ColumnValue{
sqlgen.Column{`table_catalog`},
`=`,
sqlPlaceholder,
},
sqlgen.ColumnValue{
sqlgen.Column{`table_name`},
`=`,
sqlPlaceholder,
},
},
}
var rows *sql.Rows
var err error
if rows, err = self.doQuery(stmt, self.config.Database, tableName); err != nil {
return nil, err
}
tableFields := []columnSchemaT{}
if err = sqlutil.FetchRows(rows, &tableFields); err != nil {
return nil, err
}
self.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields))
for i := range tableFields {
self.schema.TableInfo[tableName].Columns = append(self.schema.TableInfo[tableName].Columns, tableFields[i].Name)
}
return self.schema.TableInfo[tableName].Columns, nil
}
// Returns a collection instance by name. // Returns a collection instance by name.
func (self *source) Collection(names ...string) (db.Collection, error) { func (self *source) Collection(names ...string) (db.Collection, error) {
var rows *sql.Rows
var err error var err error
if len(names) == 0 { if len(names) == 0 {
...@@ -416,68 +526,37 @@ func (self *source) Collection(names ...string) (db.Collection, error) { ...@@ -416,68 +526,37 @@ func (self *source) Collection(names ...string) (db.Collection, error) {
names: names, names: names,
} }
columns_t := []columnSchema_t{}
for _, name := range names { for _, name := range names {
chunks := strings.SplitN(name, " ", 2) chunks := strings.SplitN(name, ` `, 2)
if len(chunks) > 0 {
name = chunks[0]
if err := self.tableExists(name); err != nil { if len(chunks) == 0 {
return nil, err return nil, db.ErrMissingCollectionName
} }
// Getting columns tableName := chunks[0]
rows, err = self.doQuery(sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`information_schema.columns`},
Columns: sqlgen.Columns{
{`column_name`},
{`data_type`},
},
Where: sqlgen.Where{
sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
},
}, self.config.Database, name)
if err != nil { if err := self.tableExists(tableName); err != nil {
return nil, err return nil, err
} }
if err = sqlutil.FetchRows(rows, &columns_t); err != nil { if col.Columns, err = self.tableColumns(tableName); err != nil {
return nil, err return nil, err
} }
col.Columns = make([]string, 0, len(columns_t))
for _, column := range columns_t {
col.Columns = append(col.Columns, column.Name)
}
}
} }
return col, nil return col, nil
} }
func (self *source) getPrimaryKey(tableName string) (string, error) { func (self *source) getPrimaryKey(tableName string) (string, error) {
var row *sql.Row
var err error
var pKey string
if self.primaryKeys == nil { tableSchema := self.schema.Table(tableName)
self.primaryKeys = make(map[string]string)
}
if pKey, ok := self.primaryKeys[tableName]; ok { if tableSchema.PrimaryKey != "" {
// Retrieving cached key name. return tableSchema.PrimaryKey, nil
return pKey, nil
} }
// Getting primary key. See https://github.com/upper/db/issues/24. // Getting primary key. See https://github.com/upper/db/issues/24.
row, err = self.doQueryRow(sqlgen.Statement{ stmt := sqlgen.Statement{
Type: sqlgen.SqlSelect, Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`}, Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`},
Columns: sqlgen.Columns{ Columns: sqlgen.Columns{
...@@ -491,20 +570,18 @@ func (self *source) getPrimaryKey(tableName string) (string, error) { ...@@ -491,20 +570,18 @@ func (self *source) getPrimaryKey(tableName string) (string, error) {
sqlgen.Raw{`indisprimary`}, sqlgen.Raw{`indisprimary`},
}, },
Limit: 1, Limit: 1,
}) }
var row *sql.Row
var err error
if err != nil { if row, err = self.doQueryRow(stmt); err != nil {
return "", err return "", err
} }
if err = row.Scan(&pKey); err != nil { if err = row.Scan(&tableSchema.PrimaryKey); err != nil {
return "", err return "", err
} }
// Caching key name. return tableSchema.PrimaryKey, nil
// TODO: There is currently no policy for cache life and no cache-cleaning
// methods are provided.
self.primaryKeys[tableName] = pKey
return pKey, nil
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment