diff --git a/.travis.yml b/.travis.yml index a9275520e461363c84bed6f97455877cd15497f4..395591e4d4cd56792144b3acefb28e3aca1f6745 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,4 +39,5 @@ before_script: - mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root mysql script: - - UPPERIO_DB_DEBUG=1 make test +# - UPPERIO_DB_DEBUG=1 make test + - make test diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index e015ff4501c6c15334e9a3f46b8adf5bc0a1f7de..37d786e9f522341d1191b1bc29f544a2591c4428 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -5,8 +5,8 @@ package ADAPTER import ( "database/sql" "flag" - "log" "fmt" + "log" "math/rand" "os" "strconv" @@ -117,7 +117,7 @@ func TestInsertReturning(t *testing.T) { assert.NotZero(t, itemMap["id"], "Must not be zero after inserting") itemStruct := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -133,7 +133,7 @@ func TestInsertReturning(t *testing.T) { assert.Equal(t, uint64(2), count, "Expecting 2 elements") itemStruct2 := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -180,7 +180,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) { assert.NotZero(t, itemMap["id"], "Must not be zero after inserting") itemStruct := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -196,7 +196,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) { assert.Equal(t, uint64(2), count, "Expecting 2 elements") itemStruct2 := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -362,7 +362,7 @@ func TestGetResultsOneByOne(t *testing.T) { // Dumping into a tagged struct. rowStruct2 := struct { - Value1 int64 `db:"id"` + Value1 int64 `db:"id"` Value2 string `db:"name"` }{} @@ -402,7 +402,7 @@ func TestGetResultsOneByOne(t *testing.T) { // Dumping into a slice of structs. allRowsStruct := []struct { - ID int64 `db:"id,omitempty"` + ID int64 `db:"id,omitempty"` Name string `db:"name"` }{} @@ -423,7 +423,7 @@ func TestGetResultsOneByOne(t *testing.T) { // Dumping into a slice of tagged structs. allRowsStruct2 := []struct { - Value1 int64 `db:"id"` + Value1 int64 `db:"id"` Value2 string `db:"name"` }{} @@ -536,7 +536,7 @@ func TestUpdate(t *testing.T) { // Defining destination struct value := struct { - ID int64 `db:"id,omitempty"` + ID int64 `db:"id,omitempty"` Name string `db:"name"` }{} @@ -1080,6 +1080,40 @@ func TestDataTypes(t *testing.T) { assert.Equal(t, testValues, item) } +func TestUpdateWithNullColumn(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + artist := sess.Collection("artist") + err := artist.Truncate() + assert.NoError(t, err) + + type Artist struct { + ID int64 `db:"id,omitempty"` + Name *string `db:"name"` + } + + name := "José" + id, err := artist.Insert(Artist{0, &name}) + assert.NoError(t, err) + + var item Artist + err = artist.Find(id).One(&item) + assert.NoError(t, err) + + assert.NotEqual(t, nil, item.Name) + assert.Equal(t, name, *item.Name) + + artist.Find(db.Cond{"id": id}).Update(Artist{Name: nil}) + assert.NoError(t, err) + + var item2 Artist + err = artist.Find(id).One(&item2) + assert.NoError(t, err) + + assert.Equal(t, (*string)(nil), item2.Name) +} + func TestBatchInsert(t *testing.T) { sess := mustOpen() defer sess.Close() @@ -1130,7 +1164,9 @@ func TestBatchInsertNoColumns(t *testing.T) { go func() { defer batch.Done() for i := 0; i < totalItems; i++ { - value := struct{Name string `db:"name"`}{fmt.Sprintf("artist-%d", i)} + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-%d", i)} batch.Values(value) } }() @@ -1209,7 +1245,7 @@ func TestBuilder(t *testing.T) { err := sess.Collection("artist").Truncate() assert.NoError(t, err) - _, err = sess.InsertInto("artist").Values(struct{ + _, err = sess.InsertInto("artist").Values(struct { Name string `db:"name"` }{"Rinko Kikuchi"}).Exec() assert.NoError(t, err) @@ -1285,7 +1321,6 @@ func TestBuilder(t *testing.T) { assert.Error(t, iter.Err()) } - // Using implicit iterator. q := sess.SelectFrom("artist") err = q.All(&all) @@ -1319,7 +1354,7 @@ func TestExhaustConnectionPool(t *testing.T) { t.Fatal(err) } - tLogf := func(format string, args... interface{}) { + tLogf := func(format string, args ...interface{}) { tMu.Lock() defer tMu.Unlock() t.Logf(format, args...) @@ -1349,7 +1384,7 @@ func TestExhaustConnectionPool(t *testing.T) { // transaction lasts 3 seconds. time.Sleep(time.Second * 3) - switch i%7 { + switch i % 7 { case 0: var account map[string]interface{} if err := tx.Collection("artist").Find().One(&account); err != nil { diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index bbaa08f847246acdf4389caf2b51bca0e2d7971e..03abd592bc7d9b8637c3ccbc10ab3e0c8d6acb2e 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -15,6 +15,16 @@ import ( "upper.io/db.v2/lib/reflectx" ) +type MapOptions struct { + IncludeZeroed bool + IncludeNil bool +} + +var defaultMapOptions = MapOptions{ + IncludeZeroed: false, + IncludeNil: false, +} + type hasIsZero interface { IsZero() bool } @@ -178,14 +188,19 @@ func (b *sqlBuilder) Update(table string) Updater { } // Map receives a pointer to map or struct and maps it to columns and values. -func Map(item interface{}) ([]string, []interface{}, error) { +func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) { var fv fieldValue + if options == nil { + options = &defaultMapOptions + } + itemV := reflect.ValueOf(item) itemT := itemV.Type() if itemT.Kind() == reflect.Ptr { - // Single derefence. Just in case user passed a pointer to struct instead of a struct. + // Single dereference. Just in case the user passes a pointer to struct + // instead of a struct. item = itemV.Elem().Interface() itemV = reflect.ValueOf(item) itemT = itemV.Type() @@ -202,29 +217,51 @@ func Map(item interface{}) ([]string, []interface{}, error) { fv.fields = make([]string, 0, nfields) for _, fi := range fieldMap { + + // Field options + _, tagOmitEmpty := fi.Options["omitempty"] + _, tagStringArray := fi.Options["stringarray"] + _, tagInt64Array := fi.Options["int64array"] + _, tagJSONB := fi.Options["jsonb"] + fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) if fld.Kind() == reflect.Ptr && fld.IsNil() { + if options.IncludeNil || !tagOmitEmpty { + fv.fields = append(fv.fields, fi.Name) + fv.values = append(fv.values, fld.Interface()) + } continue } var value interface{} - if _, ok := fi.Options["stringarray"]; ok { - value = stringArray(fld.Interface().([]string)) - } else if _, ok := fi.Options["int64array"]; ok { - value = int64Array(fld.Interface().([]int64)) - } else if _, ok := fi.Options["jsonb"]; ok { + switch { + case tagStringArray: + v, ok := fld.Interface().([]string) + if !ok { + return nil, nil, fmt.Errorf(`Expecting field %q to be []string (using "stringarray" tag)`, fi.Name) + } + value = stringArray(v) + case tagInt64Array: + v, ok := fld.Interface().([]int64) + if !ok { + return nil, nil, fmt.Errorf(`Expecting field %q to be []int64 (using "int64array" tag)`, fi.Name) + } + value = int64Array(v) + case tagJSONB: value = jsonbType{fld.Interface()} - } else { + default: value = fld.Interface() } - if _, ok := fi.Options["omitempty"]; ok { - if t, ok := fld.Interface().(hasIsZero); ok { - if t.IsZero() { + if !options.IncludeZeroed { + if tagOmitEmpty { + if t, ok := fld.Interface().(hasIsZero); ok { + if t.IsZero() { + continue + } + } else if value == fi.Zero.Interface() { continue } - } else if value == fi.Zero.Interface() { - continue } } diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 00e01e430a7d0a71d5e724be8eab9edcbd2e1c8b..9120f4ee9bafb2056ea308b6653a68f330aa7453 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -539,19 +539,83 @@ func TestInsert(t *testing.T) { }{12, "Chavela Vargas"}).String(), ) - assert.Equal( - `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6)`, - b.InsertInto("artist").Values(struct { - ID int `db:"id"` - Name string `db:"name"` - }{12, "Chavela Vargas"}).Values(struct { - ID int `db:"id"` - Name string `db:"name"` - }{13, "Alondra de la Parra"}).Values(struct { - ID int `db:"id"` - Name string `db:"name"` - }{14, "Haruki Murakami"}).String(), - ) + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{13, "Alondra de la Parra"}). + Values(artistStruct{14, "Haruki Murakami"}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{0, ""}). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{0, "Alondra de la Parra"}). + Values(artistStruct{14, ""}). + Values(artistStruct{0, ""}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6), ($7, $8), ($9, $10)`, + q.String(), + ) + + assert.Equal( + []interface{}{0, "", 12, "Chavela Vargas", 0, "Alondra de la Parra", 14, "", 0, ""}, + q.Arguments(), + ) + } + + { + intRef := func(i int) *int { + if i == 0 { + return nil + } + return &i + } + + strRef := func(s string) *string { + if s == "" { + return nil + } + return &s + } + + type artistStruct struct { + ID *int `db:"id,omitempty"` + Name *string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{intRef(0), strRef("")}). + Values(artistStruct{intRef(12), strRef("Chavela Vargas")}). + Values(artistStruct{intRef(0), strRef("Alondra de la Parra")}). + Values(artistStruct{intRef(14), strRef("")}). + Values(artistStruct{intRef(0), strRef("")}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6), ($7, $8), ($9, $10)`, + q.String(), + ) + + assert.Equal( + []interface{}{intRef(0), strRef(""), intRef(12), strRef("Chavela Vargas"), intRef(0), strRef("Alondra de la Parra"), intRef(14), strRef(""), intRef(0), strRef("")}, + q.Arguments(), + ) + } assert.Equal( `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index 775b13e72bd6d31d8c454d4d185a8829d2c72b88..bef31ae34663812afa32a795a4a4f67784694a87 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -27,6 +27,10 @@ func (qi *inserter) Batch(n int) *BatchInserter { return newBatchInserter(qi.clone(), n) } +func (qi *inserter) Arguments() []interface{} { + return qi.arguments +} + func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) error { l := len(columns) f := make([]exql.Fragment, l) @@ -66,7 +70,7 @@ func (qi *inserter) Columns(columns ...string) Inserter { func (qi *inserter) Values(values ...interface{}) Inserter { if len(values) == 1 { - ff, vv, err := Map(values[0]) + ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true}) if err == nil { columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index 602c1405a71f33501db9baccb4a74aa92dca9b12..d0d6de3dd7f4b681daf148a1b60b7eb1e22d3d5d 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -312,6 +312,9 @@ type Inserter interface { // i.Values(map[string][string]{"name": "MarÃa"}) Values(...interface{}) Inserter + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} + // Returning represents a RETURNING clause. // // RETURNING specifies which columns should be returned after INSERT. diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index d4d108c670a15bf0a48e14b846fe8f9830400118..6c73724b282e9d8569a194de5ea79e2560ccbe95 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -18,7 +18,7 @@ type updater struct { func (qu *updater) Set(terms ...interface{}) Updater { if len(terms) == 1 { - ff, vv, _ := Map(terms[0]) + ff, vv, _ := Map(terms[0], nil) cvs := make([]exql.Fragment, 0, len(ff)) args := make([]interface{}, 0, len(vv)) diff --git a/mysql/collection.go b/mysql/collection.go index 3d967a38d0d9f3a31b833a87a7f374912a4eea05..e9af67b1580798cb3c42ef7095df7fba87aaf7d9 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -75,7 +75,7 @@ func (t *table) Conds(conds ...interface{}) []interface{} { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err } diff --git a/postgresql/collection.go b/postgresql/collection.go index 65ea696dccbe069272b92cd0ffe3258bf0f8b7e5..2a66fc50290a09f7a689ac7abee8cd0bcbc0bb7e 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -75,7 +75,7 @@ func (t *table) Conds(conds ...interface{}) []interface{} { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err } diff --git a/ql/collection.go b/ql/collection.go index 0a0e423d1d8f6016ca8c62d6da7568545c3f689f..7d8017706b55a1563f1d7b9a0cf126fe52421fd8 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -112,7 +112,7 @@ func (t *table) Find(conds ...interface{}) db.Result { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err } diff --git a/sqlite/collection.go b/sqlite/collection.go index 214c7d531fec296b4baddfb2aafc95af32bdbf70..6ad06ad929481b325f9ae4eac15bb74686779621 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -75,7 +75,7 @@ func (t *table) Conds(conds ...interface{}) []interface{} { // Insert inserts an item (map or struct) into the collection. func (t *table) Insert(item interface{}) (interface{}, error) { - columnNames, columnValues, err := sqlbuilder.Map(item) + columnNames, columnValues, err := sqlbuilder.Map(item, nil) if err != nil { return nil, err }