From 3bad2a8555605e5b6f76f0fe57967f238c5d2bfd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Mon, 4 Aug 2014 19:31:30 -0500
Subject: [PATCH] SQLite: Adding support for nullable fields. Issue: #26.

---
 sqlite/collection.go    | 70 +++++++++++++++++++++++++-----
 sqlite/database_test.go | 94 +++++++++++++++++++++++++++++++++++++++++
 sqlite/layout.go        |  2 +
 3 files changed, 155 insertions(+), 11 deletions(-)

diff --git a/sqlite/collection.go b/sqlite/collection.go
index 5139885e..3269d91d 100644
--- a/sqlite/collection.go
+++ b/sqlite/collection.go
@@ -27,6 +27,7 @@ import (
 	"strings"
 	"time"
 
+	"database/sql"
 	"menteslibres.net/gosexy/to"
 	"upper.io/db"
 	"upper.io/db/util/sqlgen"
@@ -213,31 +214,42 @@ func (self *table) Truncate() error {
 
 // Appends an item (map or struct) into the collection.
 func (self *table) Append(item interface{}) (interface{}, error) {
-
-	cols, vals, err := self.FieldValues(item, toInternal)
-
+	var arguments []interface{}
 	var columns sqlgen.Columns
 	var values sqlgen.Values
 
-	for _, col := range cols {
-		columns = append(columns, sqlgen.Column{col})
-	}
-
-	for i := 0; i < len(vals); i++ {
-		values = append(values, sqlPlaceholder)
-	}
+	cols, vals, err := self.FieldValues(item, toInternal)
 
 	// Error ocurred, stop appending.
 	if err != nil {
 		return nil, err
 	}
 
+	columns = make(sqlgen.Columns, 0, len(cols))
+	for i := range cols {
+		columns = append(columns, sqlgen.Column{cols[i]})
+	}
+
+	arguments = make([]interface{}, 0, len(vals))
+	values = make(sqlgen.Values, 0, len(vals))
+	for i := range vals {
+		switch v := vals[i].(type) {
+		case sqlgen.Value:
+			// Adding value.
+			values = append(values, v)
+		default:
+			// Adding both value and placeholder.
+			values = append(values, sqlPlaceholder)
+			arguments = append(arguments, v)
+		}
+	}
+
 	row, err := self.source.doExec(sqlgen.Statement{
 		Type:    sqlgen.SqlInsert,
 		Table:   sqlgen.Table{self.tableN(0)},
 		Columns: columns,
 		Values:  values,
-	}, vals...)
+	}, arguments...)
 
 	if err != nil {
 		return nil, err
@@ -270,6 +282,42 @@ func toInternal(val interface{}) interface{} {
 		return t.Format(DateFormat)
 	case time.Duration:
 		return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond)
+	case sql.NullBool:
+		if t.Valid {
+			if t.Bool {
+				return toInternal(t.Bool)
+			} else {
+				return false
+			}
+		} else {
+			return sqlgen.Value{sqlgen.Raw{sqlNull}}
+		}
+	case sql.NullFloat64:
+		if t.Valid {
+			if t.Float64 != 0.0 {
+				return toInternal(t.Float64)
+			} else {
+				return float64(0)
+			}
+		} else {
+			return sqlgen.Value{sqlgen.Raw{sqlNull}}
+		}
+	case sql.NullInt64:
+		if t.Valid {
+			if t.Int64 != 0 {
+				return toInternal(t.Int64)
+			} else {
+				return 0
+			}
+		} else {
+			return sqlgen.Value{sqlgen.Raw{sqlNull}}
+		}
+	case sql.NullString:
+		if t.Valid {
+			return toInternal(t.String)
+		} else {
+			return sqlgen.Value{sqlgen.Raw{sqlNull}}
+		}
 	case bool:
 		if t == true {
 			return `1`
diff --git a/sqlite/database_test.go b/sqlite/database_test.go
index b34aa26b..438f81aa 100644
--- a/sqlite/database_test.go
+++ b/sqlite/database_test.go
@@ -215,6 +215,100 @@ func TestAppend(t *testing.T) {
 
 }
 
+// Attempts to test nullable fields.
+func TestNullableFields(t *testing.T) {
+	var err error
+	var sess db.Database
+	var col db.Collection
+	var id interface{}
+
+	if sess, err = db.Open(Adapter, settings); err != nil {
+		t.Fatal(err)
+	}
+
+	defer sess.Close()
+
+	type test_t struct {
+		Id              int64           `db:"id,omitempty"`
+		NullStringTest  sql.NullString  `db:"_string"`
+		NullInt64Test   sql.NullInt64   `db:"_int64"`
+		NullFloat64Test sql.NullFloat64 `db:"_float64"`
+		NullBoolTest    sql.NullBool    `db:"_bool"`
+	}
+
+	var test test_t
+
+	if col, err = sess.Collection(`data_types`); err != nil {
+		t.Fatal(err)
+	}
+
+	if err = col.Truncate(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Testing insertion of invalid nulls.
+	test = test_t{
+		NullStringTest:  sql.NullString{"", false},
+		NullInt64Test:   sql.NullInt64{0, false},
+		NullFloat64Test: sql.NullFloat64{0.0, false},
+		NullBoolTest:    sql.NullBool{false, false},
+	}
+	if id, err = col.Append(test_t{}); err != nil {
+		t.Fatal(err)
+	}
+
+	// Testing fetching of invalid nulls.
+	if err = col.Find(db.Cond{"id": id}).One(&test); err != nil {
+		t.Fatal(err)
+	}
+
+	if test.NullInt64Test.Valid {
+		t.Fatalf(`Expecting invalid null.`)
+	}
+	if test.NullFloat64Test.Valid {
+		t.Fatalf(`Expecting invalid null.`)
+	}
+	if test.NullBoolTest.Valid {
+		t.Fatalf(`Expecting invalid null.`)
+	}
+
+	// In PostgreSQL, how we can tell if this is an invalid null?
+
+	// if test.NullStringTest.Valid {
+	// 	t.Fatalf(`Expecting invalid null.`)
+	// }
+
+	// Testing insertion of valid nulls.
+	test = test_t{
+		NullStringTest:  sql.NullString{"", true},
+		NullInt64Test:   sql.NullInt64{0, true},
+		NullFloat64Test: sql.NullFloat64{0.0, true},
+		NullBoolTest:    sql.NullBool{false, true},
+	}
+	if id, err = col.Append(test); err != nil {
+		t.Fatal(err)
+	}
+
+	// Testing fetching of valid nulls.
+	if err = col.Find(db.Cond{"id": id}).One(&test); err != nil {
+		t.Fatal(err)
+	}
+
+	if test.NullInt64Test.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+	if test.NullFloat64Test.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+	if test.NullBoolTest.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+	if test.NullStringTest.Valid == false {
+		t.Fatalf(`Expecting valid value.`)
+	}
+
+}
+
 // Attempts to count all rows in our newly defined set.
 func TestResultCount(t *testing.T) {
 	var err error
diff --git a/sqlite/layout.go b/sqlite/layout.go
index 17f7c10a..bb750efe 100644
--- a/sqlite/layout.go
+++ b/sqlite/layout.go
@@ -121,4 +121,6 @@ const (
 	sqlDropTableLayout = `
 		DROP TABLE {{.Table}}
 	`
+
+	sqlNull = `NULL`
 )
-- 
GitLab