diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index e015ff4501c6c15334e9a3f46b8adf5bc0a1f7de..37d786e9f522341d1191b1bc29f544a2591c4428 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -5,8 +5,8 @@ package ADAPTER import ( "database/sql" "flag" - "log" "fmt" + "log" "math/rand" "os" "strconv" @@ -117,7 +117,7 @@ func TestInsertReturning(t *testing.T) { assert.NotZero(t, itemMap["id"], "Must not be zero after inserting") itemStruct := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -133,7 +133,7 @@ func TestInsertReturning(t *testing.T) { assert.Equal(t, uint64(2), count, "Expecting 2 elements") itemStruct2 := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -180,7 +180,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) { assert.NotZero(t, itemMap["id"], "Must not be zero after inserting") itemStruct := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -196,7 +196,7 @@ func TestInsertReturningWithinTransaction(t *testing.T) { assert.Equal(t, uint64(2), count, "Expecting 2 elements") itemStruct2 := struct { - ID int `db:"id,omitempty"` + ID int `db:"id,omitempty"` Name string `db:"name"` }{ 0, @@ -362,7 +362,7 @@ func TestGetResultsOneByOne(t *testing.T) { // Dumping into a tagged struct. rowStruct2 := struct { - Value1 int64 `db:"id"` + Value1 int64 `db:"id"` Value2 string `db:"name"` }{} @@ -402,7 +402,7 @@ func TestGetResultsOneByOne(t *testing.T) { // Dumping into a slice of structs. allRowsStruct := []struct { - ID int64 `db:"id,omitempty"` + ID int64 `db:"id,omitempty"` Name string `db:"name"` }{} @@ -423,7 +423,7 @@ func TestGetResultsOneByOne(t *testing.T) { // Dumping into a slice of tagged structs. allRowsStruct2 := []struct { - Value1 int64 `db:"id"` + Value1 int64 `db:"id"` Value2 string `db:"name"` }{} @@ -536,7 +536,7 @@ func TestUpdate(t *testing.T) { // Defining destination struct value := struct { - ID int64 `db:"id,omitempty"` + ID int64 `db:"id,omitempty"` Name string `db:"name"` }{} @@ -1080,6 +1080,40 @@ func TestDataTypes(t *testing.T) { assert.Equal(t, testValues, item) } +func TestUpdateWithNullColumn(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + artist := sess.Collection("artist") + err := artist.Truncate() + assert.NoError(t, err) + + type Artist struct { + ID int64 `db:"id,omitempty"` + Name *string `db:"name"` + } + + name := "José" + id, err := artist.Insert(Artist{0, &name}) + assert.NoError(t, err) + + var item Artist + err = artist.Find(id).One(&item) + assert.NoError(t, err) + + assert.NotEqual(t, nil, item.Name) + assert.Equal(t, name, *item.Name) + + artist.Find(db.Cond{"id": id}).Update(Artist{Name: nil}) + assert.NoError(t, err) + + var item2 Artist + err = artist.Find(id).One(&item2) + assert.NoError(t, err) + + assert.Equal(t, (*string)(nil), item2.Name) +} + func TestBatchInsert(t *testing.T) { sess := mustOpen() defer sess.Close() @@ -1130,7 +1164,9 @@ func TestBatchInsertNoColumns(t *testing.T) { go func() { defer batch.Done() for i := 0; i < totalItems; i++ { - value := struct{Name string `db:"name"`}{fmt.Sprintf("artist-%d", i)} + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-%d", i)} batch.Values(value) } }() @@ -1209,7 +1245,7 @@ func TestBuilder(t *testing.T) { err := sess.Collection("artist").Truncate() assert.NoError(t, err) - _, err = sess.InsertInto("artist").Values(struct{ + _, err = sess.InsertInto("artist").Values(struct { Name string `db:"name"` }{"Rinko Kikuchi"}).Exec() assert.NoError(t, err) @@ -1285,7 +1321,6 @@ func TestBuilder(t *testing.T) { assert.Error(t, iter.Err()) } - // Using implicit iterator. q := sess.SelectFrom("artist") err = q.All(&all) @@ -1319,7 +1354,7 @@ func TestExhaustConnectionPool(t *testing.T) { t.Fatal(err) } - tLogf := func(format string, args... interface{}) { + tLogf := func(format string, args ...interface{}) { tMu.Lock() defer tMu.Unlock() t.Logf(format, args...) @@ -1349,7 +1384,7 @@ func TestExhaustConnectionPool(t *testing.T) { // transaction lasts 3 seconds. time.Sleep(time.Second * 3) - switch i%7 { + switch i % 7 { case 0: var account map[string]interface{} if err := tx.Collection("artist").Find().One(&account); err != nil { diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index a155591f08a0d328abbe429ea52d2afdfcac3bf5..03abd592bc7d9b8637c3ccbc10ab3e0c8d6acb2e 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -199,7 +199,8 @@ func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) itemT := itemV.Type() if itemT.Kind() == reflect.Ptr { - // Single derefence. Just in case user passed a pointer to struct instead of a struct. + // Single dereference. Just in case the user passes a pointer to struct + // instead of a struct. item = itemV.Elem().Interface() itemV = reflect.ValueOf(item) itemT = itemV.Type() @@ -216,10 +217,16 @@ func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) fv.fields = make([]string, 0, nfields) for _, fi := range fieldMap { - fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) + // Field options + _, tagOmitEmpty := fi.Options["omitempty"] + _, tagStringArray := fi.Options["stringarray"] + _, tagInt64Array := fi.Options["int64array"] + _, tagJSONB := fi.Options["jsonb"] + + fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) if fld.Kind() == reflect.Ptr && fld.IsNil() { - if options.IncludeNil { + if options.IncludeNil || !tagOmitEmpty { fv.fields = append(fv.fields, fi.Name) fv.values = append(fv.values, fld.Interface()) } @@ -227,18 +234,27 @@ func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) } var value interface{} - if _, ok := fi.Options["stringarray"]; ok { - value = stringArray(fld.Interface().([]string)) - } else if _, ok := fi.Options["int64array"]; ok { - value = int64Array(fld.Interface().([]int64)) - } else if _, ok := fi.Options["jsonb"]; ok { + switch { + case tagStringArray: + v, ok := fld.Interface().([]string) + if !ok { + return nil, nil, fmt.Errorf(`Expecting field %q to be []string (using "stringarray" tag)`, fi.Name) + } + value = stringArray(v) + case tagInt64Array: + v, ok := fld.Interface().([]int64) + if !ok { + return nil, nil, fmt.Errorf(`Expecting field %q to be []int64 (using "int64array" tag)`, fi.Name) + } + value = int64Array(v) + case tagJSONB: value = jsonbType{fld.Interface()} - } else { + default: value = fld.Interface() } if !options.IncludeZeroed { - if _, ok := fi.Options["omitempty"]; ok { + if tagOmitEmpty { if t, ok := fld.Interface().(hasIsZero); ok { if t.IsZero() { continue