From 92517abbe4ca3f77418250a27f5c465d5f62a9b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Fri, 1 Aug 2014 13:07:45 -0500
Subject: [PATCH] PostgreSQL: Adding support for nullable fields. Issue: #26.

---
 postgresql/collection.go    | 65 +++++++++++++++++++++----
 postgresql/database_test.go | 94 +++++++++++++++++++++++++++++++++++++
 postgresql/layout.go        |  2 +
 util/sqlutil/main.go        | 58 ++++++++++++++++++-----
 4 files changed, 197 insertions(+), 22 deletions(-)

diff --git a/postgresql/collection.go b/postgresql/collection.go
index 927df783..fdbd24f9 100644
--- a/postgresql/collection.go
+++ b/postgresql/collection.go
@@ -218,21 +218,32 @@ func (self *table) Append(item interface{}) (interface{}, error) {
 	var pKey string
 	var columns sqlgen.Columns
 	var values sqlgen.Values
+	var arguments []interface{}
 	var id int64
 
 	cols, vals, err := self.FieldValues(item, toInternal)
 
-	for _, col := range cols {
-		columns = append(columns, sqlgen.Column{col})
+	if err != nil {
+		return nil, err
 	}
 
-	for i := 0; i < len(vals); i++ {
-		values = append(values, sqlPlaceholder)
+	columns = make(sqlgen.Columns, 0, len(cols))
+	for i := range cols {
+		columns = append(columns, sqlgen.Column{cols[i]})
 	}
 
-	// Error ocurred, stop appending.
-	if err != nil {
-		return nil, err
+	arguments = make([]interface{}, 0, len(vals))
+	values = make(sqlgen.Values, 0, len(vals))
+	for i := range vals {
+		switch v := vals[i].(type) {
+		case sqlgen.Value:
+			// Adding value.
+			values = append(values, v)
+		default:
+			// Adding both value and placeholder.
+			values = append(values, sqlPlaceholder)
+			arguments = append(arguments, v)
+		}
 	}
 
 	if pKey, err = self.source.getPrimaryKey(self.tableN(0)); err != nil {
@@ -252,7 +263,7 @@ func (self *table) Append(item interface{}) (interface{}, error) {
 	if pKey == "" {
 		// No primary key found.
 		var res sql.Result
-		if res, err = self.source.doExec(stmt, vals...); err != nil {
+		if res, err = self.source.doExec(stmt, arguments...); err != nil {
 			return nil, err
 		}
 
@@ -266,7 +277,7 @@ func (self *table) Append(item interface{}) (interface{}, error) {
 
 		// A primary key was found.
 		stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING %s`, pKey))
-		if row, err = self.source.doQueryRow(stmt, vals...); err != nil {
+		if row, err = self.source.doQueryRow(stmt, arguments...); err != nil {
 			return nil, err
 		}
 
@@ -306,6 +317,42 @@ func toInternal(val interface{}) interface{} {
 		return t.Format(DateFormat)
 	case time.Duration:
 		return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond)
+	case sql.NullBool:
+		if t.Valid {
+			if t.Bool {
+				return toInternal(t.Bool)
+			} else {
+				return false
+			}
+		} else {
+			return sqlgen.Value{sqlgen.Raw{psqlNull}}
+		}
+	case sql.NullFloat64:
+		if t.Valid {
+			if t.Float64 != 0.0 {
+				return toInternal(t.Float64)
+			} else {
+				return float64(0)
+			}
+		} else {
+			return sqlgen.Value{sqlgen.Raw{psqlNull}}
+		}
+	case sql.NullInt64:
+		if t.Valid {
+			if t.Int64 != 0 {
+				return toInternal(t.Int64)
+			} else {
+				return 0
+			}
+		} else {
+			return sqlgen.Value{sqlgen.Raw{psqlNull}}
+		}
+	case sql.NullString:
+		if t.Valid {
+			return toInternal(t.String)
+		} else {
+			return sqlgen.Value{sqlgen.Raw{psqlNull}}
+		}
 	case bool:
 		if t == true {
 			return `1`
diff --git a/postgresql/database_test.go b/postgresql/database_test.go
index 1480d1e3..09e6a740 100644
--- a/postgresql/database_test.go
+++ b/postgresql/database_test.go
@@ -587,6 +587,100 @@ func TestFunction(t *testing.T) {
 	res.Close()
 }
 
+// Attempts to test nullable fields.
+func TestNullableFields(t *testing.T) {
+	var err error
+	var sess db.Database
+	var col db.Collection
+	var id interface{}
+
+	if sess, err = db.Open(Adapter, settings); err != nil {
+		t.Fatal(err)
+	}
+
+	defer sess.Close()
+
+	type test_t struct {
+		Id              int64           `db:"id,omitempty"`
+		NullStringTest  sql.NullString  `db:"_string"`
+		NullInt64Test   sql.NullInt64   `db:"_int64"`
+		NullFloat64Test sql.NullFloat64 `db:"_float64"`
+		NullBoolTest    sql.NullBool    `db:"_bool"`
+	}
+
+	var test test_t
+
+	if col, err = sess.Collection(`data_types`); err != nil {
+		t.Fatal(err)
+	}
+
+	if err = col.Truncate(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Testing insertion of invalid nulls.
+	test = test_t{
+		NullStringTest:  sql.NullString{"", false},
+		NullInt64Test:   sql.NullInt64{0, false},
+		NullFloat64Test: sql.NullFloat64{0.0, false},
+		NullBoolTest:    sql.NullBool{false, false},
+	}
+	if id, err = col.Append(test_t{}); err != nil {
+		t.Fatal(err)
+	}
+
+	// Testing fetching of invalid nulls.
+	if err = col.Find(db.Cond{"id": id}).One(&test); err != nil {
+		t.Fatal(err)
+	}
+
+	if test.NullInt64Test.Valid {
+		t.Fatalf(`Expecting invalid null.`)
+	}
+	if test.NullFloat64Test.Valid {
+		t.Fatalf(`Expecting invalid null.`)
+	}
+	if test.NullBoolTest.Valid {
+		t.Fatalf(`Expecting invalid null.`)
+	}
+
+	// In PostgreSQL, how we can tell if this is an invalid null?
+
+	// if test.NullStringTest.Valid {
+	// 	t.Fatalf(`Expecting invalid null.`)
+	// }
+
+	// Testing insertion of valid nulls.
+	test = test_t{
+		NullStringTest:  sql.NullString{"", true},
+		NullInt64Test:   sql.NullInt64{0, true},
+		NullFloat64Test: sql.NullFloat64{0.0, true},
+		NullBoolTest:    sql.NullBool{false, true},
+	}
+	if id, err = col.Append(test); err != nil {
+		t.Fatal(err)
+	}
+
+	// Testing fetching of valid nulls.
+	if err = col.Find(db.Cond{"id": id}).One(&test); err != nil {
+		t.Fatal(err)
+	}
+
+	if test.NullInt64Test.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+	if test.NullFloat64Test.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+	if test.NullBoolTest.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+	if test.NullStringTest.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+
+}
+
 // Attempts to delete previously added rows.
 func TestRemove(t *testing.T) {
 	var err error
diff --git a/postgresql/layout.go b/postgresql/layout.go
index ffca4557..a0c326b5 100644
--- a/postgresql/layout.go
+++ b/postgresql/layout.go
@@ -121,4 +121,6 @@ const (
 	pgsqlDropTableLayout = `
 		DROP TABLE {{.Table}}
 	`
+
+	psqlNull = `NULL`
 )
diff --git a/util/sqlutil/main.go b/util/sqlutil/main.go
index cd909862..0ec5636e 100644
--- a/util/sqlutil/main.go
+++ b/util/sqlutil/main.go
@@ -38,6 +38,13 @@ var (
 	reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`)
 )
 
+var (
+	nullInt64Type   = reflect.TypeOf(sql.NullInt64{})
+	nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
+	nullBoolType    = reflect.TypeOf(sql.NullBool{})
+	nullStringType  = reflect.TypeOf(sql.NullString{})
+)
+
 type T struct {
 	Columns []string
 }
@@ -86,16 +93,6 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect
 	var item reflect.Value
 	var err error
 
-	expecting := len(columns)
-
-	// Allocating results.
-	values := make([]*sql.RawBytes, expecting)
-	scanArgs := make([]interface{}, expecting)
-
-	for i := range columns {
-		scanArgs[i] = &values[i]
-	}
-
 	switch item_t.Kind() {
 	case reflect.Map:
 		item = reflect.MakeMap(item_t)
@@ -105,9 +102,17 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect
 		return item, db.ErrExpectingMapOrStruct
 	}
 
-	err = rows.Scan(scanArgs...)
+	expecting := len(columns)
 
-	if err != nil {
+	// Allocating results.
+	values := make([]*sql.RawBytes, expecting)
+	scanArgs := make([]interface{}, expecting)
+
+	for i := range columns {
+		scanArgs[i] = &values[i]
+	}
+
+	if err = rows.Scan(scanArgs...); err != nil {
 		return item, err
 	}
 
@@ -117,6 +122,7 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect
 		if value != nil {
 			// Real column name
 			column := columns[i]
+
 			// Value as string.
 			svalue := string(*value)
 
@@ -153,7 +159,32 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect
 					if destf.IsValid() {
 						if cv.Type() != destf.Type() {
 							if destf.Type().Kind() != reflect.Interface {
-								cv, _ = util.StringToType(svalue, destf.Type())
+								switch destf.Type() {
+								case nullFloat64Type:
+									nullFloat64 := sql.NullFloat64{}
+									if svalue != `` {
+										nullFloat64.Scan(svalue)
+									}
+									cv = reflect.ValueOf(nullFloat64)
+								case nullInt64Type:
+									nullInt64 := sql.NullInt64{}
+									if svalue != `` {
+										nullInt64.Scan(svalue)
+									}
+									cv = reflect.ValueOf(nullInt64)
+								case nullBoolType:
+									nullBool := sql.NullBool{}
+									if svalue != `` {
+										nullBool.Scan(svalue)
+									}
+									cv = reflect.ValueOf(nullBool)
+								case nullStringType:
+									nullString := sql.NullString{}
+									nullString.Scan(svalue)
+									cv = reflect.ValueOf(nullString)
+								default:
+									cv, _ = util.StringToType(svalue, destf.Type())
+								}
 							}
 						}
 						// Copying value.
@@ -162,6 +193,7 @@ func fetchResult(item_t reflect.Type, rows *sql.Rows, columns []string) (reflect
 						}
 					}
 				}
+
 			}
 		}
 	}
-- 
GitLab