diff --git a/builder.go b/builder.go index a112e295e3ee5c5cd382e19e9f11786dcaf58a07..3dc52af6984d326f0d52a9e56467407447a9d9b7 100644 --- a/builder.go +++ b/builder.go @@ -43,8 +43,13 @@ type QuerySelector interface { type QueryInserter interface { Values(...interface{}) QueryInserter Columns(...string) QueryInserter + Extra(string) QueryInserter + + Iterator() Iterator QueryExecer + QueryGetter + fmt.Stringer } @@ -81,3 +86,9 @@ type Iterator interface { Err() error Close() error } + +type QueryTruncater interface { + Extra(s string) QueryTruncater + + fmt.Stringer +} diff --git a/builder/builder.go b/builder/builder.go index cb9ed69aa8a640465a2ba948e8930352888d58b5..a895c84f79aaa76857a457068840cd5970951a28 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -63,6 +63,16 @@ type Builder struct { t *sqlutil.TemplateWithUtils } +func (b *Builder) TruncateTable(table string) db.QueryTruncater { + qs := &QueryTruncater{ + builder: b, + table: table, + } + + qs.stringer = &stringer{qs, b.t.Template} + return qs +} + func (b *Builder) SelectAllFrom(table string) db.QuerySelector { qs := &QuerySelector{ builder: b, @@ -120,12 +130,31 @@ type QueryInserter struct { values []*sqlgen.Values columns []sqlgen.Fragment arguments []interface{} + extra string +} + +func (qi *QueryInserter) Extra(s string) db.QueryInserter { + qi.extra = s + return qi } func (qi *QueryInserter) Exec() (sql.Result, error) { return qi.builder.sess.Exec(qi.statement()) } +func (qi *QueryInserter) Query() (*sqlx.Rows, error) { + return qi.builder.sess.Query(qi.statement(), qi.arguments...) +} + +func (qi *QueryInserter) QueryRow() (*sqlx.Row, error) { + return qi.builder.sess.QueryRow(qi.statement(), qi.arguments...) +} + +func (qi *QueryInserter) Iterator() db.Iterator { + rows, err := qi.builder.sess.Query(qi.statement(), qi.arguments...) + return &iterator{rows, err} +} + func (qi *QueryInserter) Columns(columns ...string) db.QueryInserter { l := len(columns) f := make([]sqlgen.Fragment, l) @@ -166,6 +195,7 @@ func (qi *QueryInserter) statement() *sqlgen.Statement { stmt := &sqlgen.Statement{ Type: sqlgen.Insert, Table: sqlgen.TableWithName(qi.table), + Extra: sqlgen.Extra(qi.extra), } if len(qi.values) > 0 { @@ -498,6 +528,30 @@ func (qs *QuerySelector) Iterator() db.Iterator { return &iterator{rows, err} } +type QueryTruncater struct { + *stringer + builder *Builder + table string + extra string + err error +} + +func (qt *QueryTruncater) Extra(extra string) db.QueryTruncater { + qt.extra = extra + return qt +} + +func (qt *QueryTruncater) statement() *sqlgen.Statement { + + stmt := &sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(qt.table), + Extra: sqlgen.Extra(qt.extra), + } + + return stmt +} + func columnFragments(template *sqlutil.TemplateWithUtils, columns []interface{}) ([]sqlgen.Fragment, error) { l := len(columns) f := make([]sqlgen.Fragment, l) diff --git a/builder/builder_test.go b/builder/builder_test.go index 99f43870f006827edb0bc693b03c1efc4de3e980..aa510156bfbdd5388f84970b054b7204f31e1e4d 100644 --- a/builder/builder_test.go +++ b/builder/builder_test.go @@ -144,6 +144,11 @@ func TestInsert(t *testing.T) { b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), ) + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Extra(`RETURNING "id"`).String(), + ) + assert.Equal( `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), @@ -219,3 +224,13 @@ func TestDelete(t *testing.T) { b.DeleteFrom("artist").Where("id > 5").String(), ) } + +func TestTruncate(t *testing.T) { + b := &Builder{t: sqlutil.NewTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `TRUNCATE TABLE "artist" RESTART IDENTITY`, + b.TruncateTable("artist").Extra("RESTART IDENTITY").String(), + ) +} diff --git a/builder/template_test.go b/builder/template_test.go index 47a6db0256749b26bd9503a0fc0b669077dfc2fd..157e23bd9bce5654cd5aed660ce65a82292f330a 100644 --- a/builder/template_test.go +++ b/builder/template_test.go @@ -136,7 +136,7 @@ const ( ` defaultTruncateLayout = ` - TRUNCATE TABLE {{.Table}} + TRUNCATE TABLE {{.Table}} {{.Extra}} ` defaultDropDatabaseLayout = ` diff --git a/postgresql/collection.go b/postgresql/collection.go index f06f1f0eb5676bf0ee37e27268bd7aae5204b63c..d5a9fb377de1195c0b11f519d0cfe7b0ded8731e 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -26,11 +26,9 @@ import ( "fmt" "strings" - "github.com/jmoiron/sqlx" "upper.io/db" "upper.io/db/builder" "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" "upper.io/db/util/sqlutil/result" ) @@ -48,33 +46,26 @@ func (t *table) Find(conds ...interface{}) db.Result { // Truncate deletes all rows from the table. func (t *table) Truncate() error { - _, err := t.database.Exec(&sqlgen.Statement{ + stmt := sqlgen.Statement{ Type: sqlgen.Truncate, Table: sqlgen.TableWithName(t.Name()), - }) - if err != nil { - return err + Extra: sqlgen.Extra("RESTART IDENTITY"), } + if _, err := t.database.Exec(&stmt); err != nil { + return err + } return nil } // Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { columnNames, columnValues, err := builder.Map(item) - - if err != nil { - return nil, err - } - - sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - if err != nil { return nil, err } var pKey []string - if pKey, err = t.database.getPrimaryKey(t.Name()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. @@ -82,18 +73,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { } } - stmt := &sqlgen.Statement{ - Type: sqlgen.Insert, - Table: sqlgen.TableWithName(t.Name()), - Columns: sqlgenCols, - Values: sqlgenVals, - } + q := t.database.Builder().InsertInto(t.Name()). + Columns(columnNames...). + Values(columnValues...) // No primary keys defined. if len(pKey) == 0 { - var res sql.Result - if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { + var res sql.Result + if res, err = q.Exec(); err != nil { return nil, err } @@ -104,21 +92,14 @@ func (t *table) Append(item interface{}) (interface{}, error) { return lastID, nil } - var rows *sqlx.Rows - // A primary key was found. - stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) + q.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) - if rows, err = t.database.Query(stmt, sqlgenArgs...); err != nil { - return nil, err - } + var keyMap map[string]interface{} - keyMap := map[string]interface{}{} - if err := sqlutil.FetchRow(rows, &keyMap); err != nil { - rows.Close() + if err = q.Iterator().One(&keyMap); err != nil { return nil, err } - rows.Close() // Does the item satisfy the db.IDSetter interface? if setter, ok := item.(db.IDSetter); ok { diff --git a/postgresql/database.go b/postgresql/database.go index 75f84f8970d692a91c7c9b6d282abab78257ccda..010c3507505e72515f5d91c57a41e4641441177a 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -40,10 +40,6 @@ import ( "upper.io/db/util/sqlutil/tx" ) -var ( - sqlPlaceholder = sqlgen.RawValue(`?`) -) - type database struct { connURL db.ConnectionURL session *sqlx.DB @@ -247,46 +243,19 @@ func (d *database) Collections() (collections []string, err error) { return d.schema.Tables, nil } - // Schema is empty. - // Querying table names. - stmt := &sqlgen.Statement{ - Type: sqlgen.Select, - Columns: sqlgen.JoinColumns( - sqlgen.ColumnWithName(`table_name`), - ), - Table: sqlgen.TableWithName(`information_schema.tables`), - Where: sqlgen.WhereConditions( - &sqlgen.ColumnValue{ - Column: sqlgen.ColumnWithName(`table_schema`), - Operator: `=`, - Value: sqlgen.NewValue(`public`), - }, - ), - } + q := d.Builder().Select("table_name"). + From("information_schema.tables"). + Where("table_schema = ?", "public") - // Executing statement. - var rows *sqlx.Rows - if rows, err = d.Query(stmt); err != nil { - return nil, err + var row struct { + TableName string `db:"table_name"` } - collections = []string{} - - var name string - - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - rows.Close() - return nil, err - } - - // Adding table entry to schema. - d.schema.AddTable(name) - - // Adding table to collections array. - collections = append(collections, name) + iter := q.Iterator() + for iter.Next(&row) { + d.schema.AddTable(row.TableName) + collections = append(collections, row.TableName) } return collections, nil @@ -430,23 +399,18 @@ func (d *database) populateSchema() (err error) { d.schema = schema.NewDatabaseSchema() // Get database name. - stmt := &sqlgen.Statement{ - Type: sqlgen.Select, - Columns: sqlgen.JoinColumns( - sqlgen.RawValue(`CURRENT_DATABASE()`), - ), - } + q := d.Builder().Select(db.Raw{"CURRENT_DATABASE() AS name"}) - var row *sqlx.Row - - if row, err = d.QueryRow(stmt); err != nil { - return err + var row struct { + Name string `db:"name"` } - if err = row.Scan(&d.schema.Name); err != nil { + if err := q.Iterator().One(&row); err != nil { return err } + d.schema.Name = row.Name + if collections, err = d.Collections(); err != nil { return err } @@ -461,43 +425,20 @@ func (d *database) populateSchema() (err error) { } func (d *database) tableExists(names ...string) error { - var stmt *sqlgen.Statement - var err error - var rows *sqlx.Rows + var row map[string]string - for i := range names { + for _, tableName := range names { - if d.schema.HasTable(names[i]) { + if d.schema.HasTable(tableName) { // We already know this table exists. continue } - stmt = &sqlgen.Statement{ - Type: sqlgen.Select, - Table: sqlgen.TableWithName(`information_schema.tables`), - Columns: sqlgen.JoinColumns( - sqlgen.ColumnWithName(`table_name`), - ), - Where: sqlgen.WhereConditions( - &sqlgen.ColumnValue{ - Column: sqlgen.ColumnWithName(`table_catalog`), - Operator: `=`, - Value: sqlPlaceholder, - }, - &sqlgen.ColumnValue{ - Column: sqlgen.ColumnWithName(`table_name`), - Operator: `=`, - Value: sqlPlaceholder, - }, - ), - } - - if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { - return db.ErrCollectionDoesNotExist - } + q := d.Builder().Select("table_name"). + From("information_schema.tables"). + Where("table_catalog = ? AND table_name = ?", d.schema.Name, tableName) - if !rows.Next() { - rows.Close() + if err := q.Iterator().One(&row); err != nil { return db.ErrCollectionDoesNotExist } } @@ -511,54 +452,26 @@ func (d *database) TableColumns(tableName string) ([]string, error) { func (d *database) tableColumns(tableName string) ([]string, error) { - // Making sure this table is allocated. tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } - stmt := &sqlgen.Statement{ - Type: sqlgen.Select, - Table: sqlgen.TableWithName(`information_schema.columns`), - Columns: sqlgen.JoinColumns( - sqlgen.ColumnWithName(`column_name`), - sqlgen.ColumnWithName(`data_type`), - ), - Where: sqlgen.WhereConditions( - &sqlgen.ColumnValue{ - Column: sqlgen.ColumnWithName(`table_catalog`), - Operator: `=`, - Value: sqlPlaceholder, - }, - &sqlgen.ColumnValue{ - Column: sqlgen.ColumnWithName(`table_name`), - Operator: `=`, - Value: sqlPlaceholder, - }, - ), - } + q := d.Builder().Select("column_name", "data_type"). + From("information_schema.columns"). + Where("table_catalog = ? AND table_name = ?", d.schema.Name, tableName) - var rows *sqlx.Rows - var err error + var rows []columnSchemaT - if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { + if err := q.Iterator().All(&rows); err != nil { return nil, err } - tableFields := []columnSchemaT{} + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(rows)) - if err = sqlutil.FetchRows(rows, &tableFields); err != nil { - rows.Close() - return nil, err - } - - rows.Close() - - d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) - - for i := range tableFields { - d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) + for i := range rows { + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, rows[i].Name) } return d.schema.TableInfo[tableName].Columns, nil @@ -571,46 +484,26 @@ func (d *database) getPrimaryKey(tableName string) ([]string, error) { return tableSchema.PrimaryKey, nil } - // Getting primary key. See https://github.com/upper/db/issues/24. - stmt := &sqlgen.Statement{ - Type: sqlgen.Select, - Table: sqlgen.TableWithName(`pg_index, pg_class, pg_attribute`), - Columns: sqlgen.JoinColumns( - sqlgen.ColumnWithName(`pg_attribute.attname`), - ), - Where: sqlgen.WhereConditions( - sqlgen.RawValue(`pg_class.oid = '"`+tableName+`"'::regclass`), - sqlgen.RawValue(`indrelid = pg_class.oid`), - sqlgen.RawValue(`pg_attribute.attrelid = pg_class.oid`), - sqlgen.RawValue(`pg_attribute.attnum = ANY(pg_index.indkey)`), - sqlgen.RawValue(`indisprimary`), - ), - OrderBy: &sqlgen.OrderBy{ - SortColumns: sqlgen.JoinSortColumns( - &sqlgen.SortColumn{ - Column: sqlgen.ColumnWithName(`attname`), - Order: sqlgen.Ascendent, - }, - ), - }, - } - - var rows *sqlx.Rows - var err error + tableSchema.PrimaryKey = make([]string, 0, 1) - if rows, err = d.Query(stmt); err != nil { - return nil, err - } + q := d.Builder().Select("pg_attribute.attname AS pkey"). + From("pg_index", "pg_class", "pg_attribute"). + Where(` + pg_class.oid = '"` + tableName + `"'::regclass + AND indrelid = pg_class.oid + AND pg_attribute.attrelid = pg_class.oid + AND pg_attribute.attnum = ANY(pg_index.indkey) + AND indisprimary + `).OrderBy("pkey") - tableSchema.PrimaryKey = make([]string, 0, 1) + iter := q.Iterator() - for rows.Next() { - var key string - if err = rows.Scan(&key); err != nil { - rows.Close() - return nil, err - } - tableSchema.PrimaryKey = append(tableSchema.PrimaryKey, key) + var row struct { + Key string `db:"pkey"` + } + + for iter.Next(&row) { + tableSchema.PrimaryKey = append(tableSchema.PrimaryKey, row.Key) } return tableSchema.PrimaryKey, nil