Newer
Older
// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import (
"database/sql"
"fmt"
"regexp"
"strings"
_ "github.com/xiam/gopostgresql"
"upper.io/db/util/sqlutil"
// Adapter is the name this adapter passes to db.Register.
const Adapter = `postgresql`

var (
	// DateFormat is the time layout used when saving dates.
	DateFormat = "2006-01-02 15:04:05"
	// TimeFormat is the fmt format string used when saving times.
	// NOTE(review): the four fields look like hour, minute, second and a
	// fractional part — confirm against the code that formats with it.
	TimeFormat = "%d:%02d:%02d.%d"
	// SSLMode is the value written into the sslmode connection parameter.
	SSLMode = "disable"

	// columnPattern matches a column type of the form `name(digits,...) word`,
	// capturing the base name, the optional comma-separated digits and the
	// optional trailing word. NOTE(review): its users are not visible here.
	columnPattern = regexp.MustCompile(`^([a-z]+)\(?([0-9,]+)?\)?\s?([a-z]*)?`)
	// sqlPlaceholder is a raw `?` value used in generated WHERE clauses for
	// arguments bound at execution time.
	sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}}
)
config db.Settings
session *sql.DB
collections map[string]db.Collection
primaryKeys map[string]string
type columnSchema_t struct {
Name string `db:"column_name"`
// debugEnabled reports whether query debugging is on, i.e. whether the
// environment variable named by db.EnvEnableDebug holds a non-empty value.
func debugEnabled() bool {
	return os.Getenv(db.EnvEnableDebug) != ""
}
// debugLog prints query, its arguments and err through sqlutil.Debug.
// It does nothing unless debugEnabled reports true.
func debugLog(query string, args []interface{}, err error) {
	if !debugEnabled() {
		return
	}
	d := sqlutil.Debug{query, args, err}
	d.Print()
}
pgsqlColumnSeparator,
pgsqlIdentifierSeparator,
pgsqlIdentifierQuote,
pgsqlValueSeparator,
pgsqlValueQuote,
pgsqlAndKeyword,
pgsqlOrKeyword,
pgsqlNotKeyword,
pgsqlDescKeyword,
pgsqlAscKeyword,
pgsqlDefaultOperator,
pgsqlClauseGroup,
pgsqlClauseOperator,
pgsqlColumnValue,
pgsqlTableAliasLayout,
pgsqlColumnAliasLayout,
pgsqlSortByColumnLayout,
pgsqlWhereLayout,
pgsqlOrderByLayout,
pgsqlInsertLayout,
pgsqlSelectLayout,
pgsqlUpdateLayout,
pgsqlDeleteLayout,
pgsqlTruncateLayout,
pgsqlDropDatabaseLayout,
pgsqlDropTableLayout,
pgsqlSelectCountLayout,
db.Register(Adapter, &source{})
func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) {
var query string
var res sql.Result
var err error
defer func() {
debugLog(query, args, err)
}()
if self.session == nil {
l := len(args)
for i := 0; i < l; i++ {
query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1)
res, err = self.tx.Exec(query, args...)
} else {
res, err = self.session.Exec(query, args...)
func (self *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) {
var rows *sql.Rows
var query string
var err error
defer func() {
debugLog(query, args, err)
}()
if self.session == nil {
l := len(args)
for i := 0; i < l; i++ {
query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1)
}
rows, err = self.tx.Query(query, args...)
} else {
rows, err = self.session.Query(query, args...)
func (self *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) {
var row *sql.Row
var err error
defer func() {
debugLog(query, args, err)
}()
if self.session == nil {
l := len(args)
for i := 0; i < l; i++ {
query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1)
row = self.tx.QueryRow(query, args...)
} else {
row = self.session.QueryRow(query, args...)
return row, err
// Returns the string name of the database.
func (self *source) Name() string {
return self.config.Database
José Carlos Nieto
committed
}
// Ping checks that the database behind the current session is reachable,
// establishing a connection first if necessary.
func (self *source) Ping() error {
	session := self.session
	return session.Ping()
}
// clone returns a new source configured with a copy of self's settings and
// holding its own session.
//
// Setup already opens the session, so Open must not be called again here.
// A second Open would overwrite the first *sql.DB without closing it, leaking
// that pool. Setup's error is propagated instead of being discarded.
func (self *source) clone() (*source, error) {
	src := new(source)
	if err := src.Setup(self.config); err != nil {
		return nil, err
	}
	return src, nil
}
// Clone returns a new db.Database with the same settings as this source and
// a session of its own.
//
// On failure it returns a literal nil db.Database. Returning the *source from
// clone directly would produce a non-nil interface wrapping a nil pointer, so
// callers checking `d != nil` would be misled.
func (self *source) Clone() (db.Database, error) {
	src, err := self.clone()
	if err != nil {
		return nil, err
	}
	return src, nil
}
// Transaction begins a *sql.Tx on the current session, clones this source,
// stores the transaction in the clone's tx field and returns tx as a db.Tx.
//
// NOTE(review): err, clone and tx are not declared anywhere in the visible
// body; their declarations appear to have been lost. Confirm against the
// full file.
func (self *source) Transaction() (db.Tx, error) {
	var sqlTx *sql.Tx

	if sqlTx, err = self.session.Begin(); err != nil {
		return nil, err
	}

	// NOTE(review): if cloning fails, sqlTx has already begun and is not
	// rolled back here. Consider calling sqlTx.Rollback() on this path.
	if clone, err = self.clone(); err != nil {
		return nil, err
	}

	// Attach the transaction to the clone; self keeps its own state.
	clone.tx = sqlTx

	return tx, nil
}
// Setup makes config the active settings, starts a fresh, empty collection
// cache and then opens a session with those settings. It returns Open's
// error.
func (self *source) Setup(config db.Settings) error {
	self.collections = make(map[string]db.Collection)
	self.config = config
	return self.Open()
}
// Driver exposes the *sql.DB session that backs this source, boxed as an
// interface{}.
func (self *source) Driver() interface{} {
	var driver interface{} = self.session
	return driver
}
// Attempts to connect to a database using the stored settings.
func (self *source) Open() error {
var err error
if self.config.Host == "" {
self.config.Host = `127.0.0.1`
if self.config.Port == 0 {
self.config.Port = 5432
if self.config.Database == "" {
if self.config.Socket != "" && self.config.Host != "" {
if user := self.config.User; user != "" {
conn += fmt.Sprintf(`user=%s `, user)
}
if pass := self.config.Password; pass != "" {
conn += fmt.Sprintf(`password=%s `, pass)
}
conn += fmt.Sprintf(`host=%s port=%d `, self.config.Host, self.config.Port)
} else {
conn += fmt.Sprintf(`host=%s `, self.config.Socket)
conn += fmt.Sprintf(`dbname=%s sslmode=%s`, self.config.Database, SSLMode)
self.session, err = sql.Open(`postgres`, conn)
if err != nil {
}
return nil
}
// Close closes the current database session. When there is no session
// (self.session is nil) it does nothing and returns nil.
func (self *source) Close() error {
	if self.session == nil {
		return nil
	}
	return self.session.Close()
}
// Changes the active database.
func (self *source) Use(database string) error {
self.config.Database = database
return self.Open()
// Drops the currently active database.
func (self *source) Drop() error {
_, err := self.doQuery(sqlgen.Statement{
Type: sqlgen.SqlDropDatabase,
Database: sqlgen.Database{self.config.Database},
})
return err
// Returns a list of all tables within the currently active database.
func (self *source) Collections() ([]string, error) {
var collections []string
var collection string
rows, err := self.doQuery(sqlgen.Statement{
Type: sqlgen.SqlSelect,
Columns: sqlgen.Columns{
{"table_name"},
},
Table: sqlgen.Table{"information_schema.tables"},
Where: sqlgen.Where{
sqlgen.ColumnValue{sqlgen.Column{"table_schema"}, "=", sqlgen.Value{"public"}},
},
})
if err != nil {
return nil, err
for rows.Next() {
rows.Scan(&collection)
collections = append(collections, collection)
return collections, nil
func (self *source) tableExists(names ...string) error {
for _, name := range names {
rows, err := self.doQuery(sqlgen.Statement{
Type: sqlgen.SqlSelect,
Table: sqlgen.Table{`information_schema.tables`},
sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
}, self.config.Database, name)
// Returns a collection instance by name.
func (self *source) Collection(names ...string) (db.Collection, error) {
var rows *sql.Rows
var err error
if len(names) == 0 {
return nil, db.ErrMissingCollectionName
columns_t := []columnSchema_t{}
for _, name := range names {
chunks := strings.SplitN(name, " ", 2)
if err := self.tableExists(name); err != nil {
return nil, err
}
// Getting columns
rows, err = self.doQuery(sqlgen.Statement{
Table: sqlgen.Table{`information_schema.columns`},
{`column_name`},
{`data_type`},
sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder},
sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder},
}, self.config.Database, name)
if err != nil {
return nil, err
}
if err = sqlutil.FetchRows(rows, &columns_t); err != nil {
col.Columns = make([]string, 0, len(columns_t))
for _, column := range columns_t {
col.Columns = append(col.Columns, strings.ToLower(column.Name))
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
// getPrimaryKey returns the name of a primary-key column of tableName.
//
// It queries the pg_index, pg_class and pg_attribute catalogs. Because the
// query uses LIMIT 1, only one column name is returned. Results are cached
// per table in self.primaryKeys, so the query runs only on a cache miss.
// Errors from running the query or scanning its single row are returned as-is.
func (self *source) getPrimaryKey(tableName string) (string, error) {
	if self.primaryKeys == nil {
		self.primaryKeys = make(map[string]string)
	}

	// Retrieving cached key name.
	if cached, found := self.primaryKeys[tableName]; found {
		return cached, nil
	}

	// Getting primary key. See https://github.com/upper/db/issues/24.
	// The table name is bound to the `?` placeholder and cast to regclass.
	stmt := sqlgen.Statement{
		Type:  sqlgen.SqlSelect,
		Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`},
		Columns: sqlgen.Columns{
			{`pg_attribute.attname`},
		},
		Where: sqlgen.Where{
			sqlgen.ColumnValue{sqlgen.Column{`pg_class.oid`}, `=`, sqlgen.Value{sqlgen.Raw{`?::regclass`}}},
			sqlgen.ColumnValue{sqlgen.Column{`indrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}},
			sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}},
			sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attnum`}, `=`, sqlgen.Value{sqlgen.Raw{`any(pg_index.indkey)`}}},
			sqlgen.Raw{`indisprimary`},
		},
		Limit: 1,
	}

	row, err := self.doQueryRow(stmt, tableName)
	if err != nil {
		return "", err
	}

	var pKey string
	if err = row.Scan(&pKey); err != nil {
		return "", err
	}

	// Caching key name.
	// TODO: There is currently no policy for cache life and no cache-cleaning
	// methods are provided.
	self.primaryKeys[tableName] = pKey

	return pKey, nil
}