diff --git a/postgresql/local_test.go b/postgresql/local_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0a17bbe835e71ed50bda2db2240691c1959bfbf0 --- /dev/null +++ b/postgresql/local_test.go @@ -0,0 +1,83 @@ +package postgresql + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "upper.io/db.v2" +) + +func TestStringAndInt64Array(t *testing.T) { + sess := mustOpen() + driver := sess.Driver().(*sql.DB) + + defer func() { + driver.Exec(`DROP TABLE IF EXISTS array_types`) + sess.Close() + }() + + if _, err := driver.Exec(` + CREATE TABLE array_types ( + id serial primary key, + integers bigint[] DEFAULT NULL, + strings varchar(64)[] + )`); err != nil { + assert.NoError(t, err) + } + + arrayTypes := sess.Collection("array_types") + err := arrayTypes.Truncate() + assert.NoError(t, err) + + type arrayType struct { + ID int64 `db:"id,pk"` + Integers []int64 `db:"integers,int64array"` + Strings []string `db:"strings,stringarray"` + } + + tt := []arrayType{ + // Test nil arrays. + arrayType{ + ID: 1, + Integers: nil, + Strings: nil, + }, + + // Test empty arrays. + arrayType{ + ID: 2, + Integers: []int64{}, + Strings: []string{}, + }, + + // Test non-empty arrays. + arrayType{ + ID: 3, + Integers: []int64{1, 2, 3}, + Strings: []string{"1", "2", "3"}, + }, + } + + for _, item := range tt { + id, err := arrayTypes.Insert(item) + assert.NoError(t, err) + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var itemCheck arrayType + err = arrayTypes.Find(db.Cond{"id": id}).One(&itemCheck) + assert.NoError(t, err) + assert.Len(t, itemCheck.Integers, len(item.Integers)) + assert.Len(t, itemCheck.Strings, len(item.Strings)) + + // Check nil/zero values just to make sure that the arrays won't + // be JSON-marshalled into `null` instead of empty array `[]`. + assert.NotNil(t, itemCheck.Integers) + assert.NotNil(t, itemCheck.Strings) + assert.NotZero(t, itemCheck.Integers) + assert.NotZero(t, itemCheck.Strings) + } +} diff --git a/sqlbuilder/scanner.go b/sqlbuilder/scanner.go index 8b4fda6b6b898e818b00b37ad423cb58d82e5bef..9e841d431182b47158965e75befdf678a4b35a8a 100644 --- a/sqlbuilder/scanner.go +++ b/sqlbuilder/scanner.go @@ -28,6 +28,7 @@ import ( "errors" "strconv" "strings" + "upper.io/db.v2" ) @@ -141,6 +142,7 @@ type int64Array []int64 func (a *int64Array) Scan(src interface{}) error { if src == nil { + *a = int64Array{} return nil } b, ok := src.([]byte) @@ -152,7 +154,7 @@ func (a *int64Array) Scan(src interface{}) error { } s := string(b)[1 : len(b)-1] - var results []int64 + results := []int64{} if s != "" { parts := strings.Split(s, ",") for _, n := range parts {