diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index bbaa08f847246acdf4389caf2b51bca0e2d7971e..a155591f08a0d328abbe429ea52d2afdfcac3bf5 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -15,6 +15,16 @@ import ( "upper.io/db.v2/lib/reflectx" ) +type MapOptions struct { + IncludeZeroed bool + IncludeNil bool +} + +var defaultMapOptions = MapOptions{ + IncludeZeroed: false, + IncludeNil: false, +} + type hasIsZero interface { IsZero() bool } @@ -178,9 +188,13 @@ func (b *sqlBuilder) Update(table string) Updater { } // Map receives a pointer to map or struct and maps it to columns and values. -func Map(item interface{}) ([]string, []interface{}, error) { +func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) { var fv fieldValue + if options == nil { + options = &defaultMapOptions + } + itemV := reflect.ValueOf(item) itemT := itemV.Type() @@ -203,7 +217,12 @@ func Map(item interface{}) ([]string, []interface{}, error) { for _, fi := range fieldMap { fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) + if fld.Kind() == reflect.Ptr && fld.IsNil() { + if options.IncludeNil { + fv.fields = append(fv.fields, fi.Name) + fv.values = append(fv.values, fld.Interface()) + } continue } @@ -218,13 +237,15 @@ func Map(item interface{}) ([]string, []interface{}, error) { value = fld.Interface() } - if _, ok := fi.Options["omitempty"]; ok { - if t, ok := fld.Interface().(hasIsZero); ok { - if t.IsZero() { + if !options.IncludeZeroed { + if _, ok := fi.Options["omitempty"]; ok { + if t, ok := fld.Interface().(hasIsZero); ok { + if t.IsZero() { + continue + } + } else if value == fi.Zero.Interface() { continue } - } else if value == fi.Zero.Interface() { - continue } } diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 00e01e430a7d0a71d5e724be8eab9edcbd2e1c8b..9120f4ee9bafb2056ea308b6653a68f330aa7453 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -539,19 +539,83 @@ func TestInsert(t *testing.T) { }{12, "Chavela Vargas"}).String(), ) - assert.Equal( - `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6)`, - b.InsertInto("artist").Values(struct { - ID int `db:"id"` - Name string `db:"name"` - }{12, "Chavela Vargas"}).Values(struct { - ID int `db:"id"` - Name string `db:"name"` - }{13, "Alondra de la Parra"}).Values(struct { - ID int `db:"id"` - Name string `db:"name"` - }{14, "Haruki Murakami"}).String(), - ) + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{13, "Alondra de la Parra"}). + Values(artistStruct{14, "Haruki Murakami"}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{0, ""}). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{0, "Alondra de la Parra"}). + Values(artistStruct{14, ""}). + Values(artistStruct{0, ""}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6), ($7, $8), ($9, $10)`, + q.String(), + ) + + assert.Equal( + []interface{}{0, "", 12, "Chavela Vargas", 0, "Alondra de la Parra", 14, "", 0, ""}, + q.Arguments(), + ) + } + + { + intRef := func(i int) *int { + if i == 0 { + return nil + } + return &i + } + + strRef := func(s string) *string { + if s == "" { + return nil + } + return &s + } + + type artistStruct struct { + ID *int `db:"id,omitempty"` + Name *string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{intRef(0), strRef("")}). + Values(artistStruct{intRef(12), strRef("Chavela Vargas")}). + Values(artistStruct{intRef(0), strRef("Alondra de la Parra")}). + Values(artistStruct{intRef(14), strRef("")}). + Values(artistStruct{intRef(0), strRef("")}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6), ($7, $8), ($9, $10)`, + q.String(), + ) + + assert.Equal( + []interface{}{intRef(0), strRef(""), intRef(12), strRef("Chavela Vargas"), intRef(0), strRef("Alondra de la Parra"), intRef(14), strRef(""), intRef(0), strRef("")}, + q.Arguments(), + ) + } assert.Equal( `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index 775b13e72bd6d31d8c454d4d185a8829d2c72b88..bef31ae34663812afa32a795a4a4f67784694a87 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -27,6 +27,10 @@ func (qi *inserter) Batch(n int) *BatchInserter { return newBatchInserter(qi.clone(), n) } +func (qi *inserter) Arguments() []interface{} { + return qi.arguments +} + func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) error { l := len(columns) f := make([]exql.Fragment, l) @@ -66,7 +70,7 @@ func (qi *inserter) Columns(columns ...string) Inserter { func (qi *inserter) Values(values ...interface{}) Inserter { if len(values) == 1 { - ff, vv, err := Map(values[0]) + ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true}) if err == nil { columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index 602c1405a71f33501db9baccb4a74aa92dca9b12..d0d6de3dd7f4b681daf148a1b60b7eb1e22d3d5d 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -312,6 +312,9 @@ type Inserter interface { // i.Values(map[string][string]{"name": "MarÃa"}) Values(...interface{}) Inserter + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} + // Returning represents a RETURNING clause. // // RETURNING specifies which columns should be returned after INSERT. diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index d4d108c670a15bf0a48e14b846fe8f9830400118..6c73724b282e9d8569a194de5ea79e2560ccbe95 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -18,7 +18,7 @@ type updater struct { func (qu *updater) Set(terms ...interface{}) Updater { if len(terms) == 1 { - ff, vv, _ := Map(terms[0]) + ff, vv, _ := Map(terms[0], nil) cvs := make([]exql.Fragment, 0, len(ff)) args := make([]interface{}, 0, len(vv)) diff --git a/mysql/collection.go b/mysql/collection.go index 3d967a38d0d9f3a31b833a87a7f374912a4eea05..e9af67b1580798cb3c42ef7095df7fba87aaf7d9 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -75,7 +75,7 @@ func (t *table) Conds(conds ...interface{}) []interface{} { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err } diff --git a/postgresql/collection.go b/postgresql/collection.go index 65ea696dccbe069272b92cd0ffe3258bf0f8b7e5..2a66fc50290a09f7a689ac7abee8cd0bcbc0bb7e 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -75,7 +75,7 @@ func (t *table) Conds(conds ...interface{}) []interface{} { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err } diff --git a/ql/collection.go b/ql/collection.go index 0a0e423d1d8f6016ca8c62d6da7568545c3f689f..7d8017706b55a1563f1d7b9a0cf126fe52421fd8 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -112,7 +112,7 @@ func (t *table) Find(conds ...interface{}) db.Result { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err } diff --git a/sqlite/collection.go b/sqlite/collection.go index 214c7d531fec296b4baddfb2aafc95af32bdbf70..6ad06ad929481b325f9ae4eac15bb74686779621 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -75,7 +75,7 @@ func (t *table) Conds(conds ...interface{}) []interface{} { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err }