diff --git a/builder/builder.go b/builder/builder.go index b594c3c121abeacb8ad0d860fa42a1496d5abe5a..29157104ed152f4b0f5a841789ed17eb7659543d 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -37,8 +37,6 @@ type sqlDatabase interface { Query(stmt *sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) QueryRow(stmt *sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) Exec(stmt *sqlgen.Statement, args ...interface{}) (sql.Result, error) - - TableColumns(tableName string) ([]string, error) } type Builder struct { @@ -120,7 +118,18 @@ func (qi *QueryInserter) Columns(columns ...string) db.QueryInserter { } func (qi *QueryInserter) Values(values ...interface{}) db.QueryInserter { - if len(qi.columns) == 0 || len(values) == len(qi.columns) { + if len(qi.columns) == 0 && len(values) == 1 { + ff, vv, _ := Map(values[0]) + + columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) + + qi.arguments = append(qi.arguments, arguments...) + qi.values = append(qi.values, vals) + + for _, c := range columns.Columns { + qi.columns = append(qi.columns, c) + } + } else if len(qi.columns) == 0 || len(values) == len(qi.columns) { qi.arguments = append(qi.arguments, values...) l := len(values) @@ -204,8 +213,7 @@ type QueryUpdater struct { func (qu *QueryUpdater) Set(terms ...interface{}) db.QueryUpdater { if len(terms) == 1 { - columns, _ := qu.builder.sess.TableColumns(qu.table) - ff, vv, _ := fieldValues(columns, terms[0]) + ff, vv, _ := Map(terms[0]) cvs := make([]sqlgen.Fragment, len(ff)) @@ -216,7 +224,6 @@ func (qu *QueryUpdater) Set(terms ...interface{}) db.QueryUpdater { Value: sqlPlaceholder, } } - qu.columnValues.Append(cvs...) qu.arguments = append(qu.arguments, vv...) } else if len(terms) > 1 { @@ -605,7 +612,7 @@ func (iter *iterator) Close() (err error) { return err } -func fieldValues(columns []string, item interface{}) ([]string, []interface{}, error) { +func Map(item interface{}) ([]string, []interface{}, error) { fields := []string{} values := []interface{}{} @@ -672,7 +679,7 @@ func fieldValues(columns []string, item interface{}) ([]string, []interface{}, e for i, keyV := range mkeys { valv := itemV.MapIndex(keyV) - fields[i] = columnLike(columns, fmt.Sprintf("%v", keyV.Interface())) + fields[i] = fmt.Sprintf("%v", keyV.Interface()) v, err := marshal(valv.Interface()) if err != nil { @@ -681,7 +688,6 @@ func fieldValues(columns []string, item interface{}) ([]string, []interface{}, e values[i] = v } - default: return nil, nil, db.ErrExpectingMapOrStruct } diff --git a/postgresql/collection.go b/postgresql/collection.go index b14360453b06e437124d4c0942ca08a58837cab2..5656e6e616f91670ec1216b375f69c03ebfc5d3a 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -28,6 +28,7 @@ import ( "github.com/jmoiron/sqlx" "upper.io/db" + "upper.io/db/builder" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" "upper.io/db/util/sqlutil/result" @@ -42,7 +43,7 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(conds ...interface{}) db.Result { - return result.NewResult(t.database.Builder(), t, conds) + return result.NewResult(t.database.Builder(), t.Name(), conds) } // Truncate deletes all rows from the table. @@ -51,7 +52,6 @@ func (t *table) Truncate() error { Type: sqlgen.Truncate, Table: sqlgen.TableWithName(t.MainTableName()), }) - if err != nil { return err } @@ -61,8 +61,7 @@ func (t *table) Truncate() error { // Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - - columnNames, columnValues, err := t.FieldValues(item) + columnNames, columnValues, err := builder.Map(item) if err != nil { return nil, err diff --git a/postgresql/database.go b/postgresql/database.go index 6487869e1ea989aadcb45b6cc64d4e54ab438700..2dc9a99e9eec56e9ccb6ab7afbb863dea168a522 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -152,8 +152,6 @@ func (d *database) Open() error { return err } - d.session.Mapper = sqlutil.NewMapper() - d.cachedStatements = cache.NewCache() d.collections = make(map[string]*table) @@ -231,7 +229,6 @@ func (d *database) Collection(names ...string) (db.Collection, error) { col := &table{database: d} col.T.Tables = names - col.T.Mapper = d.session.Mapper for _, name := range names { chunks := strings.SplitN(name, ` `, 2) diff --git a/postgresql/database_test.go b/postgresql/database_test.go index 467eea642e53c9d755108ea44aa0b01cfe5708e1..f98e862e20f0c806330f3635a85afb954ed3e43a 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -2054,6 +2054,19 @@ func TestQueryBuilder(t *testing.T) { String(), ) + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + assert.Equal( `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), diff --git a/util/sqlutil/debug.go b/util/sqlutil/debug.go index 3c997891db4a50bf7d59572c60c2bbbffff31615..2357b59aefea5df0ae84cc9ad153bbff89f02faa 100644 --- a/util/sqlutil/debug.go +++ b/util/sqlutil/debug.go @@ -25,11 +25,17 @@ import ( "fmt" "log" "os" + "regexp" "strings" "upper.io/db" ) +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) + reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`) +) + func init() { if os.Getenv(db.EnvEnableDebug) != "" { db.Debug = true diff --git a/util/sqlutil/fetch.go b/util/sqlutil/fetch.go index a37296286da017fb3a8b1fedd2ca8cba81a65dde..823bf64b12df0aae85dc59cf159679d3950815ca 100644 --- a/util/sqlutil/fetch.go +++ b/util/sqlutil/fetch.go @@ -282,3 +282,12 @@ func fetchResult(itemT reflect.Type, rows *sqlx.Rows, columns []string) (reflect return item, nil } + +func reset(data interface{}) error { + // Resetting element. + v := reflect.ValueOf(data).Elem() + t := v.Type() + z := reflect.Zero(t) + v.Set(z) + return nil +} diff --git a/util/sqlutil/result/result.go b/util/sqlutil/result/result.go index cc2345d0d7700898984dcc4c4e22c11b6125ef95..2e50932e1e4e16bdc991fbe3fed76d2c375a812c 100644 --- a/util/sqlutil/result/result.go +++ b/util/sqlutil/result/result.go @@ -27,7 +27,7 @@ import ( type Result struct { b db.QueryBuilder - dp DataProvider + table string iter db.Iterator limit int offset int @@ -40,10 +40,10 @@ type Result struct { // NewResult creates and results a new result set on the given table, this set // is limited by the given sqlgen.Where conditions. -func NewResult(b db.QueryBuilder, dp DataProvider, conds []interface{}) *Result { +func NewResult(b db.QueryBuilder, table string, conds []interface{}) *Result { return &Result{ b: b, - dp: dp, + table: table, conds: conds, } } @@ -111,7 +111,7 @@ func (r *Result) Next(dst interface{}) (err error) { // Removes the matching items from the collection. func (r *Result) Remove() error { - q := r.b.DeleteFrom(r.dp.Name()). + q := r.b.DeleteFrom(r.table). Where(r.conds...). Limit(r.limit) @@ -130,7 +130,7 @@ func (r *Result) Close() error { // Updates matching items from the collection with values of the given map or // struct. func (r *Result) Update(values interface{}) error { - q := r.b.Update(r.dp.Name()). + q := r.b.Update(r.table). Set(values). Where(r.conds...). Limit(r.limit) @@ -158,7 +158,7 @@ func (r *Result) Count() (uint64, error) { func (r *Result) buildSelect() db.QuerySelector { q := r.b.Select(r.fields...) - q.From(r.dp.Name()) + q.From(r.table) q.Where(r.conds...) q.Limit(r.limit) q.Offset(r.offset) diff --git a/util/sqlutil/result/table.go b/util/sqlutil/result/table.go deleted file mode 100644 index 12076d80629c02daec0635f3e29e40a86cd023de..0000000000000000000000000000000000000000 --- a/util/sqlutil/result/table.go +++ /dev/null @@ -1,15 +0,0 @@ -package result - -import ( - "database/sql" - "github.com/jmoiron/sqlx" - "upper.io/db/util/sqlgen" -) - -type DataProvider interface { - Name() string - Query(*sqlgen.Statement, ...interface{}) (*sqlx.Rows, error) - QueryRow(*sqlgen.Statement, ...interface{}) (*sqlx.Row, error) - Exec(*sqlgen.Statement, ...interface{}) (sql.Result, error) - FieldValues(interface{}) ([]string, []interface{}, error) -} diff --git a/util/sqlutil/sqlutil.go b/util/sqlutil/sqlutil.go index 4faafac3533817824e6bdf54f4bbbb5a08a0b842..3290adfa2e5451eaad77cd63e666a493ece7ada4 100644 --- a/util/sqlutil/sqlutil.go +++ b/util/sqlutil/sqlutil.go @@ -22,159 +22,14 @@ package sqlutil import ( - "database/sql" - "fmt" - "reflect" - "regexp" "strings" - // "crypto/md5" - - "github.com/jmoiron/sqlx/reflectx" - "upper.io/db" -) - -var ( - reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) - reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`) ) -var ( - nullInt64Type = reflect.TypeOf(sql.NullInt64{}) - nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) - nullBoolType = reflect.TypeOf(sql.NullBool{}) - nullStringType = reflect.TypeOf(sql.NullString{}) -) - -// T type is commonly used by adapters to map database/sql values to Go values -// using FieldValues() type T struct { Columns []string - Mapper *reflectx.Mapper Tables []string // Holds table names. } -func (t *T) columnLike(s string) string { - for _, name := range t.Columns { - if normalizeColumn(s) == normalizeColumn(name) { - return name - } - } - return s -} - -func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { - fields := []string{} - values := []interface{}{} - - itemV := reflect.ValueOf(item) - itemT := itemV.Type() - - if itemT.Kind() == reflect.Ptr { - // Single derefence. Just in case user passed a pointer to struct instead of a struct. - item = itemV.Elem().Interface() - itemV = reflect.ValueOf(item) - itemT = itemV.Type() - } - - switch itemT.Kind() { - - case reflect.Struct: - - fieldMap := t.Mapper.TypeMap(itemT).Names - nfields := len(fieldMap) - - values = make([]interface{}, 0, nfields) - fields = make([]string, 0, nfields) - - for _, fi := range fieldMap { - // log.Println("=>", fi.Name, fi.Options) - - fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) - if fld.Kind() == reflect.Ptr && fld.IsNil() { - continue - } - - var value interface{} - if _, ok := fi.Options["stringarray"]; ok { - value = StringArray(fld.Interface().([]string)) - } else if _, ok := fi.Options["int64array"]; ok { - value = Int64Array(fld.Interface().([]int64)) - } else if _, ok := fi.Options["jsonb"]; ok { - value = JsonbType{fld.Interface()} - } else { - value = fld.Interface() - } - - if _, ok := fi.Options["omitempty"]; ok { - if value == fi.Zero.Interface() { - continue - } - } - - // TODO: columnLike stuff...? - - fields = append(fields, fi.Name) - v, err := marshal(value) - if err != nil { - return nil, nil, err - } - values = append(values, v) - } - - case reflect.Map: - nfields := itemV.Len() - values = make([]interface{}, nfields) - fields = make([]string, nfields) - mkeys := itemV.MapKeys() - - for i, keyV := range mkeys { - valv := itemV.MapIndex(keyV) - fields[i] = t.columnLike(fmt.Sprintf("%v", keyV.Interface())) - - v, err := marshal(valv.Interface()) - if err != nil { - return nil, nil, err - } - - values[i] = v - } - - default: - return nil, nil, db.ErrExpectingMapOrStruct - } - - return fields, values, nil -} - -func marshal(v interface{}) (interface{}, error) { - if m, isMarshaler := v.(db.Marshaler); isMarshaler { - var err error - if v, err = m.MarshalDB(); err != nil { - return nil, err - } - } - return v, nil -} - -func reset(data interface{}) error { - // Resetting element. - v := reflect.ValueOf(data).Elem() - t := v.Type() - z := reflect.Zero(t) - v.Set(z) - return nil -} - -// normalizeColumn prepares a column for comparison against another column. -func normalizeColumn(s string) string { - return strings.ToLower(reColumnCompareExclude.ReplaceAllString(s, "")) -} - -// NewMapper creates a reflectx.Mapper -func NewMapper() *reflectx.Mapper { - return reflectx.NewMapper("db") -} - // MainTableName returns the name of the first table. func (t *T) MainTableName() string { return t.NthTableName(0)