From 92517abbe4ca3f77418250a27f5c465d5f62a9b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Fri, 1 Aug 2014 13:07:45 -0500 Subject: [PATCH] PostgreSQL: Adding support for nullable fields. Issue: #26. --- postgresql/collection.go | 65 +++++++++++++++++++++---- postgresql/database_test.go | 94 +++++++++++++++++++++++++++++++++++++ postgresql/layout.go | 2 + util/sqlutil/main.go | 58 ++++++++++++++++++----- 4 files changed, 197 insertions(+), 22 deletions(-) diff --git a/postgresql/collection.go b/postgresql/collection.go index 927df783..fdbd24f9 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -218,21 +218,32 @@ func (self *table) Append(item interface{}) (interface{}, error) { var pKey string var columns sqlgen.Columns var values sqlgen.Values + var arguments []interface{} var id int64 cols, vals, err := self.FieldValues(item, toInternal) - for _, col := range cols { - columns = append(columns, sqlgen.Column{col}) + if err != nil { + return nil, err } - for i := 0; i < len(vals); i++ { - values = append(values, sqlPlaceholder) + columns = make(sqlgen.Columns, 0, len(cols)) + for i := range cols { + columns = append(columns, sqlgen.Column{cols[i]}) } - // Error ocurred, stop appending. - if err != nil { - return nil, err + arguments = make([]interface{}, 0, len(vals)) + values = make(sqlgen.Values, 0, len(vals)) + for i := range vals { + switch v := vals[i].(type) { + case sqlgen.Value: + // Adding value. + values = append(values, v) + default: + // Adding both value and placeholder. + values = append(values, sqlPlaceholder) + arguments = append(arguments, v) + } } if pKey, err = self.source.getPrimaryKey(self.tableN(0)); err != nil { @@ -252,7 +263,7 @@ func (self *table) Append(item interface{}) (interface{}, error) { if pKey == "" { // No primary key found. var res sql.Result - if res, err = self.source.doExec(stmt, vals...); err != nil { + if res, err = self.source.doExec(stmt, arguments...); err != nil { return nil, err } @@ -266,7 +277,7 @@ func (self *table) Append(item interface{}) (interface{}, error) { // A primary key was found. stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING %s`, pKey)) - if row, err = self.source.doQueryRow(stmt, vals...); err != nil { + if row, err = self.source.doQueryRow(stmt, arguments...); err != nil { return nil, err } @@ -306,6 +317,42 @@ func toInternal(val interface{}) interface{} { return t.Format(DateFormat) case time.Duration: return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond) + case sql.NullBool: + if t.Valid { + if t.Bool { + return toInternal(t.Bool) + } else { + return false + } + } else { + return sqlgen.Value{sqlgen.Raw{psqlNull}} + } + case sql.NullFloat64: + if t.Valid { + if t.Float64 != 0.0 { + return toInternal(t.Float64) + } else { + return float64(0) + } + } else { + return sqlgen.Value{sqlgen.Raw{psqlNull}} + } + case sql.NullInt64: + if t.Valid { + if t.Int64 != 0 { + return toInternal(t.Int64) + } else { + return 0 + } + } else { + return sqlgen.Value{sqlgen.Raw{psqlNull}} + } + case sql.NullString: + if t.Valid { + return toInternal(t.String) + } else { + return sqlgen.Value{sqlgen.Raw{psqlNull}} + } case bool: if t == true { return `1` diff --git a/postgresql/database_test.go b/postgresql/database_test.go index 1480d1e3..09e6a740 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -587,6 +587,100 @@ func TestFunction(t *testing.T) { res.Close() } +// Attempts to test nullable fields. +func TestNullableFields(t *testing.T) { + var err error + var sess db.Database + var col db.Collection + var id interface{} + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + type test_t struct { + Id int64 `db:"id,omitempty"` + NullStringTest sql.NullString `db:"_string"` + NullInt64Test sql.NullInt64 `db:"_int64"` + NullFloat64Test sql.NullFloat64 `db:"_float64"` + NullBoolTest sql.NullBool `db:"_bool"` + } + + var test test_t + + if col, err = sess.Collection(`data_types`); err != nil { + t.Fatal(err) + } + + if err = col.Truncate(); err != nil { + t.Fatal(err) + } + + // Testing insertion of invalid nulls. + test = test_t{ + NullStringTest: sql.NullString{"", false}, + NullInt64Test: sql.NullInt64{0, false}, + NullFloat64Test: sql.NullFloat64{0.0, false}, + NullBoolTest: sql.NullBool{false, false}, + } + if id, err = col.Append(test_t{}); err != nil { + t.Fatal(err) + } + + // Testing fetching of invalid nulls. + if err = col.Find(db.Cond{"id": id}).One(&test); err != nil { + t.Fatal(err) + } + + if test.NullInt64Test.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullFloat64Test.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullBoolTest.Valid { + t.Fatalf(`Expecting invalid null.`) + } + + // In PostgreSQL, how we can tell if this is an invalid null? + + // if test.NullStringTest.Valid { + // t.Fatalf(`Expecting invalid null.`) + // } + + // Testing insertion of valid nulls. + test = test_t{ + NullStringTest: sql.NullString{"", true}, + NullInt64Test: sql.NullInt64{0, true}, + NullFloat64Test: sql.NullFloat64{0.0, true}, + NullBoolTest: sql.NullBool{false, true}, + } + if id, err = col.Append(test); err != nil { + t.Fatal(err) + } + + // Testing fetching of valid nulls. + if err = col.Find(db.Cond{"id": id}).One(&test); err != nil { + t.Fatal(err) + } + + if test.NullInt64Test.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullFloat64Test.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullBoolTest.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullStringTest.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + +} + // Attempts to delete previously added rows. func TestRemove(t *testing.T) { var err error diff --git a/postgresql/layout.go b/postgresql/layout.go index ffca4557..a0c326b5 100644 --- a/postgresql/layout.go +++ b/postgresql/layout.go @@ -121,4 +121,6 @@ const ( pgsqlDropTableLayout = ` DROP TABLE {{.Table}} ` + + psqlNull = `NULL` ) diff --git a/util/sqlutil/main.go b/util/sqlutil/main.go index cd909862..0ec5636e 100644 --- a/util/sqlutil/main.go +++ b/util/sqlutil/main.go @@ -38,6 +38,13 @@ var ( reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) ) +var ( + nullInt64Type = reflect.TypeOf(sql.NullInt64{}) + nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) + nullBoolType = reflect.TypeOf(sql.NullBool{}) + nullStringType = reflect.TypeOf(sql.NullString{}) +) + type T struct { Columns []string } @@ -86,16 +93,6 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect var item reflect.Value var err error - expecting := len(columns) - - // Allocating results. - values := make([]*sql.RawBytes, expecting) - scanArgs := make([]interface{}, expecting) - - for i := range columns { - scanArgs[i] = &values[i] - } - switch item_t.Kind() { case reflect.Map: item = reflect.MakeMap(item_t) @@ -105,9 +102,17 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect return item, db.ErrExpectingMapOrStruct } - err = rows.Scan(scanArgs...) + expecting := len(columns) - if err != nil { + // Allocating results. + values := make([]*sql.RawBytes, expecting) + scanArgs := make([]interface{}, expecting) + + for i := range columns { + scanArgs[i] = &values[i] + } + + if err = rows.Scan(scanArgs...); err != nil { return item, err } @@ -117,6 +122,7 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect if value != nil { // Real column name column := columns[i] + // Value as string. svalue := string(*value) @@ -153,7 +159,32 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect if destf.IsValid() { if cv.Type() != destf.Type() { if destf.Type().Kind() != reflect.Interface { - cv, _ = util.StringToType(svalue, destf.Type()) + switch destf.Type() { + case nullFloat64Type: + nullFloat64 := sql.NullFloat64{} + if svalue != `` { + nullFloat64.Scan(svalue) + } + cv = reflect.ValueOf(nullFloat64) + case nullInt64Type: + nullInt64 := sql.NullInt64{} + if svalue != `` { + nullInt64.Scan(svalue) + } + cv = reflect.ValueOf(nullInt64) + case nullBoolType: + nullBool := sql.NullBool{} + if svalue != `` { + nullBool.Scan(svalue) + } + cv = reflect.ValueOf(nullBool) + case nullStringType: + nullString := sql.NullString{} + nullString.Scan(svalue) + cv = reflect.ValueOf(nullString) + default: + cv, _ = util.StringToType(svalue, destf.Type()) + } } } // Copying value. @@ -162,6 +193,7 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect } } } + } } } -- GitLab