diff --git a/postgresql/collection.go b/postgresql/collection.go index ea1e3a687dcd136fcd7c208712369e1bc0e6ce78..cfcc4b47a5c7c804fee314ea4dc0951966088b88 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -47,50 +47,52 @@ func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { switch t := term.(type) { case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) + for i := range t { + w, v := whereValues(t[i]) args = append(args, v...) - where = append(where, w...) + where.Conditions = append(where.Conditions, w.Conditions...) } + return case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) + var op sqlgen.And + for i := range t { + k, v := whereValues(t[i]) args = append(args, v...) - and = append(and, k...) + op.Conditions = append(op.Conditions, k.Conditions...) } - where = append(where, and) + where.Conditions = append(where.Conditions, &op) + return case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) + var op sqlgen.Or + for i := range t { + w, v := whereValues(t[i]) args = append(args, v...) - or = append(or, k...) + op.Conditions = append(op.Conditions, w.Conditions...) } - where = append(where, or) + where.Conditions = append(where.Conditions, &op) + return case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) + if s, ok := t.Value.(string); ok { + where.Conditions = append(where.Conditions, sqlgen.RawValue(s)) } + return case db.Cond: - k, v := conditionValues(t) + cv, v := columnValues(t) args = append(args, v...) - for _, kk := range k { - where = append(where, kk) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) } + return case db.Constrainer: - k, v := conditionValues(t.Constraint()) + cv, v := columnValues(t.Constraint()) args = append(args, v...) - for _, kk := range k { - where = append(where, kk) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) + return } - return where, args + panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), term)) } func interfaceArgs(value interface{}) (args []interface{}) { @@ -122,17 +124,17 @@ func interfaceArgs(value interface{}) (args []interface{}) { return args } -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { +func columnValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { args = []interface{}{} for column, value := range cond { - var columnValue sqlgen.ColumnValue + columnValue := sqlgen.ColumnValue{} // Guessing operator from input, or using a default one. column := strings.TrimSpace(column) chunks := strings.SplitN(column, ` `, 2) - columnValue.Column = sqlgen.Column{chunks[0]} + columnValue.Column = sqlgen.ColumnWithName(chunks[0]) if len(chunks) > 1 { columnValue.Operator = chunks[1] @@ -142,30 +144,29 @@ func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []int switch value := value.(type) { case db.Func: - // Catches functions. v := interfaceArgs(value.Args) columnValue.Operator = value.Name if v == nil { // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} + columnValue.Value = sqlgen.RawValue(`()`) } else { // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) } args = append(args, v...) default: - // Catches everything else. v := interfaceArgs(value) + l := len(v) if v == nil || l == 0 { // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{psqlNull}} + columnValue.Value = sqlgen.RawValue(psqlNull) } else { if l > 1 { // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) } else { // Single value given. columnValue.Value = sqlPlaceholder @@ -174,7 +175,7 @@ func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []int } } - columnValues = append(columnValues, columnValue) + columnValues.ColumnValues = append(columnValues.ColumnValues, &columnValue) } return columnValues, args @@ -205,8 +206,8 @@ func (t *table) tableN(i int) string { // Deletes all the rows within the collection. func (t *table) Truncate() error { _, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{t.tableN(0)}, + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.tableN(0)), }) if err != nil { @@ -225,27 +226,30 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - var columns sqlgen.Columns + columns := new(sqlgen.Columns) - columns = make(sqlgen.Columns, 0, len(cols)) + columns.Columns = make([]sqlgen.Fragment, 0, len(cols)) for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) + columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(cols[i])) } - var values sqlgen.Values + values := new(sqlgen.Values) var arguments []interface{} arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) + values.Values = make([]sqlgen.Fragment, 0, len(vals)) for i := range vals { switch v := vals[i].(type) { + case *sqlgen.Value: + // Adding value. + values.Values = append(values.Values, v) case sqlgen.Value: // Adding value. - values = append(values, v) + values.Values = append(values.Values, &v) default: // Adding both value and placeholder. - values = append(values, sqlPlaceholder) + values.Values = append(values.Values, sqlPlaceholder) arguments = append(arguments, v) } } @@ -260,8 +264,8 @@ func (t *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{t.tableN(0)}, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.tableN(0)), Columns: columns, Values: values, } diff --git a/postgresql/database.go b/postgresql/database.go index 158f9a8d04db07ffe47e425561ab2279e20ded1c..401dd31b6ab448a053728bf2b48fa3dddd8dbf5a 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -48,7 +48,7 @@ var ( template *sqlgen.Template // Query statement placeholder - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) type source struct { @@ -204,20 +204,18 @@ func (s *source) Collections() (collections []string, err error) { // Querying table names. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Table: sqlgen.Table{ - `information_schema.tables`, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_schema`}, - `=`, - sqlgen.Value{`public`}, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Table: sqlgen.TableWithName(`information_schema.tables`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlgen.NewValue(`public`), }, - }, + ), } // Executing statement. @@ -266,8 +264,8 @@ func (s *source) Use(database string) (err error) { // Drops the currently active database. func (s *source) Drop() error { _, err := s.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlDropDatabase, - Database: sqlgen.Database{s.schema.Name}, + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(s.schema.Name), }) return err } @@ -409,10 +407,10 @@ func (s *source) populateSchema() (err error) { // Get database name. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {sqlgen.Raw{`CURRENT_DATABASE()`}}, - }, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.RawValue(`CURRENT_DATABASE()`), + ), } var row *sqlx.Row @@ -453,15 +451,23 @@ func (s *source) tableExists(names ...string) error { } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.tables`}, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.tables`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_catalog`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { @@ -488,26 +494,24 @@ func (s *source) tableColumns(tableName string) ([]string, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{ - `information_schema.columns`, - }, - Columns: sqlgen.Columns{ - {`column_name`}, - {`data_type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_catalog`}, - `=`, - sqlPlaceholder, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.columns`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`column_name`), + sqlgen.ColumnWithName(`data_type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_catalog`), + Operator: `=`, + Value: sqlPlaceholder, }, - sqlgen.ColumnValue{ - sqlgen.Column{`table_name`}, - `=`, - sqlPlaceholder, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, }, - }, + ), } var rows *sqlx.Rows @@ -543,25 +547,25 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { // Getting primary key. See https://github.com/upper/db/issues/24. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`}, - Columns: sqlgen.Columns{ - {`pg_attribute.attname`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`pg_class.oid`}, `=`, sqlgen.Value{sqlgen.Raw{`'"` + tableName + `"'::regclass`}}}, - sqlgen.ColumnValue{sqlgen.Column{`indrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, - sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, - sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attnum`}, `=`, sqlgen.Value{sqlgen.Raw{`any(pg_index.indkey)`}}}, - sqlgen.Raw{`indisprimary`}, - }, - OrderBy: sqlgen.OrderBy{ - sqlgen.SortColumns{ - { - sqlgen.Column{`attname`}, - sqlgen.SqlSortAsc, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`pg_index, pg_class, pg_attribute`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`pg_attribute.attname`), + ), + Where: sqlgen.WhereConditions( + sqlgen.RawValue(`pg_class.oid = '`+tableName+`'::regclass`), + sqlgen.RawValue(`indrelid = pg_class.oid`), + sqlgen.RawValue(`pg_attribute.attrelid = pg_class.oid`), + sqlgen.RawValue(`pg_attribute.attnum = ANY(pg_index.indkey)`), + sqlgen.RawValue(`indisprimary`), + ), + OrderBy: &sqlgen.OrderBy{ + SortColumns: sqlgen.JoinSortColumns( + &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(`attname`), + Order: sqlgen.Ascendent, }, - }, + ), }, } diff --git a/postgresql/database_test.go b/postgresql/database_test.go index fd83b09a0d5db685c111d56cdde7560b9e42d4a9..867185e864530a9809218e15bf5c5ef21f32d67b 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -1600,7 +1600,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec(`TRUNCATE TABLE "artist"`); err != nil { b.Fatal(err) @@ -1654,7 +1654,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/postgresql/layout.go b/postgresql/layout.go index 1a1bd8d8ac6690ec932c23e8fdfbed664f33e790..930e6c8bc58fc643b58547d2ad0a119289607d04 100644 --- a/postgresql/layout.go +++ b/postgresql/layout.go @@ -24,7 +24,7 @@ package postgresql const ( pgsqlColumnSeparator = `.` pgsqlIdentifierSeparator = `, ` - pgsqlIdentifierQuote = `"{{.Raw}}"` + pgsqlIdentifierQuote = `"{{.Value}}"` pgsqlValueSeparator = `, ` pgsqlValueQuote = `'{{.}}'` pgsqlAndKeyword = `AND` @@ -38,7 +38,7 @@ const ( pgsqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` pgsqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` pgsqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlSortByColumnLayout = `{{.Column}} {{.Sort}}` + pgsqlSortByColumnLayout = `{{.Column}} {{.Order}}` pgsqlOrderByLayout = ` {{if .SortColumns}} diff --git a/postgresql/result.go b/postgresql/result.go index de0bb1d47053d133eeae423290f27c0a4360540a..b1d145b09dde3f6ce05a968f44f90b059f1100fe 100644 --- a/postgresql/result.go +++ b/postgresql/result.go @@ -53,14 +53,14 @@ func (r *result) setCursor() error { // We need a cursor, if the cursor does not exists yet then we create one. if r.cursor == nil { r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(r.table.Name()), + Columns: &r.columns, Limit: r.limit, Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, + Where: &r.where, + OrderBy: &r.orderBy, + GroupBy: &r.groupBy, }, r.arguments...) } return err @@ -88,20 +88,18 @@ func (r *result) Skip(n uint) db.Result { // Used to group results that have the same value in the same column or // columns. func (r *result) Group(fields ...interface{}) db.Result { + var columns []sqlgen.Fragment - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) + for i := range fields { + switch v := fields[i].(type) { + case string: + columns = append(columns, sqlgen.ColumnWithName(v)) + case sqlgen.Fragment: + columns = append(columns, v) } } - r.groupBy = groupByColumns + r.groupBy = *sqlgen.GroupByColumns(columns...) return r } @@ -111,37 +109,36 @@ func (r *result) Group(fields ...interface{}) db.Result { // used otherwise. func (r *result) Sort(fields ...interface{}) db.Result { - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) + var sortColumns sqlgen.SortColumns - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn + for i := range fields { + var sort *sqlgen.SortColumn switch value := fields[i].(type) { case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)), + Order: sqlgen.Ascendent, } case string: if strings.HasPrefix(value, `-`) { // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value[1:]), + Order: sqlgen.Descendent, } } else { // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value), + Order: sqlgen.Ascendent, } } } - sortColumns = append(sortColumns, sort) + sortColumns.Columns = append(sortColumns.Columns, sort) } - r.orderBy.SortColumns = sortColumns + r.orderBy.SortColumns = &sortColumns return r } @@ -149,11 +146,10 @@ func (r *result) Sort(fields ...interface{}) db.Result { // Retrieves only the given fields. func (r *result) Select(fields ...interface{}) db.Result { - r.columns = make(sqlgen.Columns, 0, len(fields)) + r.columns = sqlgen.Columns{} - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column + for i := range fields { + var col sqlgen.Fragment switch value := fields[i].(type) { case db.Func: v := interfaceArgs(value.Args) @@ -167,13 +163,13 @@ func (r *result) Select(fields ...interface{}) db.Result { } s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) } - col = sqlgen.Column{sqlgen.Raw{s}} + col = sqlgen.RawValue(s) case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} + col = sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)) default: - col = sqlgen.Column{value} + col = sqlgen.ColumnWithName(fmt.Sprintf(`%v`, value)) } - r.columns = append(r.columns, col) + r.columns.Columns = append(r.columns.Columns, col) } return r @@ -238,9 +234,9 @@ func (r *result) Remove() error { var err error _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, + Type: sqlgen.Delete, + Table: sqlgen.TableWithName(r.table.Name()), + Where: &r.where, }, r.arguments...) return err @@ -256,21 +252,19 @@ func (r *result) Update(values interface{}) error { return err } - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) + cvs := new(sqlgen.ColumnValues) - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) + for i := range ff { + cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: "=", Value: sqlPlaceholder}) } vv = append(vv, r.arguments...) _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, + Type: sqlgen.Update, + Table: sqlgen.TableWithName(r.table.Name()), ColumnValues: cvs, - Where: r.where, + Where: &r.where, }, vv...) return err @@ -290,9 +284,9 @@ func (r *result) Count() (uint64, error) { var count counter row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, + Type: sqlgen.Count, + Table: sqlgen.TableWithName(r.table.Name()), + Where: &r.where, }, r.arguments...) if err != nil {