diff --git a/postgresql/collection.go b/postgresql/collection.go index 3f71764f10c449e4786ce484af03795cbc713bc5..abdaeb550421d6f5bee23fa34beb9bab25d0a04d 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -35,17 +35,14 @@ import ( type table struct { sqlutil.T - *source + *database primaryKey string names []string } -// Find creates a result set with the given conditions. -func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) -} +var _ = db.Collection(&table{}) +// tableN returns the nth name provided to the table. func (t *table) tableN(i int) string { if len(t.names) > i { chunks := strings.SplitN(t.names[i], " ", 2) @@ -56,9 +53,15 @@ func (t *table) tableN(i int) string { return "" } -// Truncate deletes all rows within the table. +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := sqlutil.ToWhereWithArguments(terms) + return result.NewResult(t, where, arguments) +} + +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - _, err := t.source.Exec(sqlgen.Statement{ + _, err := t.database.Exec(sqlgen.Statement{ Type: sqlgen.Truncate, Table: sqlgen.TableWithName(t.tableN(0)), }) @@ -109,7 +112,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - if pKey, err = t.source.getPrimaryKey(t.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.tableN(0)); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -127,7 +130,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { if len(pKey) == 0 { var res sql.Result - if res, err = t.source.Exec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, arguments...); err != nil { return nil, err } @@ -142,7 +145,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { // A primary key was found. stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) - if rows, err = t.source.Query(stmt, arguments...); err != nil { + if rows, err = t.database.Query(stmt, arguments...); err != nil { return nil, err } @@ -191,7 +194,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { // Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.names...); err != nil { return false } return true diff --git a/postgresql/database.go b/postgresql/database.go index f4d51bc185e1a414a438cc3a6d7640e1a46344e7..fc5e5dd75c81bf817a19c608eceda2cbeb8d2421 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -42,7 +42,7 @@ var ( sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB tx *sqltx.Tx @@ -51,9 +51,14 @@ type source struct { type tx struct { *sqltx.Tx - *source + *database } +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) + type columnSchemaT struct { Name string `db:"column_name"` DataType string `db:"data_type"` @@ -74,12 +79,12 @@ func debugLog(query string, args []interface{}, err error, start int64, end int6 } // Driver returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { +func (s *database) Driver() interface{} { return s.session } -// Open attempts to connect to the PostgreSQL server using the stored settings. -func (s *source) Open() error { +// Open attempts to connect to the PostgreSQL server using already stored settings. +func (s *database) Open() error { var err error // Before db.ConnectionURL we used a unified db.Settings struct. This @@ -112,13 +117,14 @@ func (s *source) Open() error { return nil } -// Clone returns a cloned db.Database session. -func (s *source) Clone() (db.Database, error) { +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (s *database) Clone() (db.Database, error) { return s.clone() } -func (s *source) clone() (*source, error) { - src := new(source) +func (s *database) clone() (*database, error) { + src := new(database) src.Setup(s.connURL) if err := src.Open(); err != nil { @@ -130,12 +136,12 @@ func (s *source) clone() (*source, error) { // Ping checks whether a connection to the database is still alive by pinging // it, establishing a connection if necessary. -func (s *source) Ping() error { +func (s *database) Ping() error { return s.session.Ping() } // Close terminates the current database session. -func (s *source) Close() error { +func (s *database) Close() error { if s.session != nil { return s.session.Close() } @@ -143,7 +149,7 @@ func (s *source) Close() error { } // Collection returns a table by name. -func (s *source) Collection(names ...string) (db.Collection, error) { +func (s *database) Collection(names ...string) (db.Collection, error) { var err error if len(names) == 0 { @@ -157,8 +163,8 @@ func (s *source) Collection(names ...string) (db.Collection, error) { } col := &table{ - source: s, - names: names, + database: s, + names: names, } for _, name := range names { @@ -182,8 +188,8 @@ func (s *source) Collection(names ...string) (db.Collection, error) { return col, nil } -// Collections returns a list of non-system tables within the database. -func (s *source) Collections() (collections []string, err error) { +// Collections returns a list of non-system tables from the database. +func (s *database) Collections() (collections []string, err error) { tablesInSchema := len(s.schema.Tables) @@ -240,7 +246,7 @@ func (s *source) Collections() (collections []string, err error) { } // Use changes the active database. -func (s *source) Use(database string) (err error) { +func (s *database) Use(database string) (err error) { var conn ConnectionURL if conn, err = ParseURL(s.connURL.String()); err != nil { @@ -254,8 +260,8 @@ func (s *source) Use(database string) (err error) { return s.Open() } -// Drop removes all tables within the current database. -func (s *source) Drop() error { +// Drop removes all tables from the current database. +func (s *database) Drop() error { _, err := s.Query(sqlgen.Statement{ Type: sqlgen.DropDatabase, Database: sqlgen.DatabaseWithName(s.schema.Name), @@ -264,21 +270,21 @@ func (s *source) Drop() error { } // Setup stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { +func (s *database) Setup(connURL db.ConnectionURL) error { s.connURL = connURL return s.Open() } // Name returns the name of the database. -func (s *source) Name() string { +func (s *database) Name() string { return s.schema.Name } // Transaction starts a transaction block and returns a db.Tx struct that can // be used to issue transactional queries. -func (s *source) Transaction() (db.Tx, error) { +func (s *database) Transaction() (db.Tx, error) { var err error - var clone *source + var clone *database var sqlTx *sqlx.Tx if sqlTx, err = s.session.Beginx(); err != nil { @@ -291,10 +297,11 @@ func (s *source) Transaction() (db.Tx, error) { clone.tx = sqltx.New(sqlTx) - return tx{Tx: clone.tx, source: clone}, nil + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (s *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -327,7 +334,8 @@ func (s *source) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, e return res, err } -func (s *source) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (s *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -360,7 +368,8 @@ func (s *source) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, return rows, err } -func (s *source) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (s *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -393,7 +402,9 @@ func (s *source) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row return row, err } -func (s *source) populateSchema() (err error) { +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (s *database) populateSchema() (err error) { var collections []string s.schema = schema.NewDatabaseSchema() @@ -416,13 +427,11 @@ func (s *source) populateSchema() (err error) { return err } - // The Collections() call will populate schema if its nil. if collections, err = s.Collections(); err != nil { return err } for i := range collections { - // Populate each collection. if _, err = s.Collection(collections[i]); err != nil { return err } @@ -431,7 +440,7 @@ func (s *source) populateSchema() (err error) { return err } -func (s *source) tableExists(names ...string) error { +func (s *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows @@ -469,7 +478,7 @@ func (s *source) tableExists(names ...string) error { defer rows.Close() - if rows.Next() == false { + if !rows.Next() { return db.ErrCollectionDoesNotExist } } @@ -477,7 +486,7 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (s *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. tableSchema := s.schema.Table(tableName) @@ -531,7 +540,7 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return s.schema.TableInfo[tableName].Columns, nil } -func (s *source) getPrimaryKey(tableName string) ([]string, error) { +func (s *database) getPrimaryKey(tableName string) ([]string, error) { tableSchema := s.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { @@ -569,6 +578,8 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { return nil, err } + defer rows.Close() + tableSchema.PrimaryKey = make([]string, 0, 1) for rows.Next() { diff --git a/postgresql/database_test.go b/postgresql/database_test.go index 867185e864530a9809218e15bf5c5ef21f32d67b..0e7e9350d3815fa41b786317b388e7076eb730d4 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -40,9 +40,9 @@ import ( ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" ) const ( @@ -50,7 +50,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, Options: map[string]string{ @@ -181,7 +181,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: password, @@ -195,7 +195,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: "fail", @@ -219,7 +219,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: "fail", Password: password, @@ -236,7 +236,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, Host: host, diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index 8aa0d45d0ef900a76c2545fcba5140c55b8863cb..f7fac8b1015eeb45d745640df141a62bf3c18a8b 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -30,45 +30,40 @@ import ( // Adapter is the public name of the adapter. const Adapter = `postgresql` -var ( - _ = db.Database(&source{}) - _ = db.Collection(&table{}) -) - var template *sqlgen.Template func init() { template = &sqlgen.Template{ - ColumnSeparator: pgsqlColumnSeparator, - IdentifierSeparator: pgsqlIdentifierSeparator, - IdentifierQuote: pgsqlIdentifierQuote, - ValueSeparator: pgsqlValueSeparator, - ValueQuote: pgsqlValueQuote, - AndKeyword: pgsqlAndKeyword, - OrKeyword: pgsqlOrKeyword, - NotKeyword: pgsqlNotKeyword, - DescKeyword: pgsqlDescKeyword, - AscKeyword: pgsqlAscKeyword, - DefaultOperator: pgsqlDefaultOperator, - ClauseGroup: pgsqlClauseGroup, - ClauseOperator: pgsqlClauseOperator, - ColumnValue: pgsqlColumnValue, - TableAliasLayout: pgsqlTableAliasLayout, - ColumnAliasLayout: pgsqlColumnAliasLayout, - SortByColumnLayout: pgsqlSortByColumnLayout, - WhereLayout: pgsqlWhereLayout, - OrderByLayout: pgsqlOrderByLayout, - InsertLayout: pgsqlInsertLayout, - SelectLayout: pgsqlSelectLayout, - UpdateLayout: pgsqlUpdateLayout, - DeleteLayout: pgsqlDeleteLayout, - TruncateLayout: pgsqlTruncateLayout, - DropDatabaseLayout: pgsqlDropDatabaseLayout, - DropTableLayout: pgsqlDropTableLayout, - CountLayout: pgsqlSelectCountLayout, - GroupByLayout: pgsqlGroupByLayout, + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), } - db.Register(Adapter, &source{}) + db.Register(Adapter, &database{}) } diff --git a/postgresql/template.go b/postgresql/template.go index 930e6c8bc58fc643b58547d2ad0a119289607d04..7c8f13a4c9a39af67f1e1af40931c70be70db0af 100644 --- a/postgresql/template.go +++ b/postgresql/template.go @@ -22,37 +22,37 @@ package postgresql const ( - pgsqlColumnSeparator = `.` - pgsqlIdentifierSeparator = `, ` - pgsqlIdentifierQuote = `"{{.Value}}"` - pgsqlValueSeparator = `, ` - pgsqlValueQuote = `'{{.}}'` - pgsqlAndKeyword = `AND` - pgsqlOrKeyword = `OR` - pgsqlNotKeyword = `NOT` - pgsqlDescKeyword = `DESC` - pgsqlAscKeyword = `ASC` - pgsqlDefaultOperator = `=` - pgsqlClauseGroup = `({{.}})` - pgsqlClauseOperator = ` {{.}} ` - pgsqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - pgsqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlSortByColumnLayout = `{{.Column}} {{.Order}}` - - pgsqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - pgsqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - pgsqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -79,19 +79,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - pgsqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - pgsqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - pgsqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -106,7 +106,7 @@ const ( {{end}} ` - pgsqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -114,23 +114,21 @@ const ( {{.Extra}} ` - pgsqlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} RESTART IDENTITY ` - pgsqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - pgsqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - pgsqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - psqlNull = `NULL` )