From 3bad2a8555605e5b6f76f0fe57967f238c5d2bfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Mon, 4 Aug 2014 19:31:30 -0500 Subject: [PATCH] SQLite: Adding support for nullable fields. Issue: #26. --- sqlite/collection.go | 70 +++++++++++++++++++++++++----- sqlite/database_test.go | 94 +++++++++++++++++++++++++++++++++++++++++ sqlite/layout.go | 2 + 3 files changed, 155 insertions(+), 11 deletions(-) diff --git a/sqlite/collection.go b/sqlite/collection.go index 5139885e..3269d91d 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -27,6 +27,7 @@ import ( "strings" "time" + "database/sql" "menteslibres.net/gosexy/to" "upper.io/db" "upper.io/db/util/sqlgen" @@ -213,31 +214,42 @@ func (self *table) Truncate() error { // Appends an item (map or struct) into the collection. func (self *table) Append(item interface{}) (interface{}, error) { - - cols, vals, err := self.FieldValues(item, toInternal) - + var arguments []interface{} var columns sqlgen.Columns var values sqlgen.Values - for _, col := range cols { - columns = append(columns, sqlgen.Column{col}) - } - - for i := 0; i < len(vals); i++ { - values = append(values, sqlPlaceholder) - } + cols, vals, err := self.FieldValues(item, toInternal) // Error ocurred, stop appending. if err != nil { return nil, err } + columns = make(sqlgen.Columns, 0, len(cols)) + for i := range cols { + columns = append(columns, sqlgen.Column{cols[i]}) + } + + arguments = make([]interface{}, 0, len(vals)) + values = make(sqlgen.Values, 0, len(vals)) + for i := range vals { + switch v := vals[i].(type) { + case sqlgen.Value: + // Adding value. + values = append(values, v) + default: + // Adding both value and placeholder. + values = append(values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + row, err := self.source.doExec(sqlgen.Statement{ Type: sqlgen.SqlInsert, Table: sqlgen.Table{self.tableN(0)}, Columns: columns, Values: values, - }, vals...) + }, arguments...) if err != nil { return nil, err @@ -270,6 +282,42 @@ func toInternal(val interface{}) interface{} { return t.Format(DateFormat) case time.Duration: return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond) + case sql.NullBool: + if t.Valid { + if t.Bool { + return toInternal(t.Bool) + } else { + return false + } + } else { + return sqlgen.Value{sqlgen.Raw{sqlNull}} + } + case sql.NullFloat64: + if t.Valid { + if t.Float64 != 0.0 { + return toInternal(t.Float64) + } else { + return float64(0) + } + } else { + return sqlgen.Value{sqlgen.Raw{sqlNull}} + } + case sql.NullInt64: + if t.Valid { + if t.Int64 != 0 { + return toInternal(t.Int64) + } else { + return 0 + } + } else { + return sqlgen.Value{sqlgen.Raw{sqlNull}} + } + case sql.NullString: + if t.Valid { + return toInternal(t.String) + } else { + return sqlgen.Value{sqlgen.Raw{sqlNull}} + } case bool: if t == true { return `1` diff --git a/sqlite/database_test.go b/sqlite/database_test.go index b34aa26b..438f81aa 100644 --- a/sqlite/database_test.go +++ b/sqlite/database_test.go @@ -215,6 +215,100 @@ func TestAppend(t *testing.T) { } +// Attempts to test nullable fields. +func TestNullableFields(t *testing.T) { + var err error + var sess db.Database + var col db.Collection + var id interface{} + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + type test_t struct { + Id int64 `db:"id,omitempty"` + NullStringTest sql.NullString `db:"_string"` + NullInt64Test sql.NullInt64 `db:"_int64"` + NullFloat64Test sql.NullFloat64 `db:"_float64"` + NullBoolTest sql.NullBool `db:"_bool"` + } + + var test test_t + + if col, err = sess.Collection(`data_types`); err != nil { + t.Fatal(err) + } + + if err = col.Truncate(); err != nil { + t.Fatal(err) + } + + // Testing insertion of invalid nulls. + test = test_t{ + NullStringTest: sql.NullString{"", false}, + NullInt64Test: sql.NullInt64{0, false}, + NullFloat64Test: sql.NullFloat64{0.0, false}, + NullBoolTest: sql.NullBool{false, false}, + } + if id, err = col.Append(test_t{}); err != nil { + t.Fatal(err) + } + + // Testing fetching of invalid nulls. + if err = col.Find(db.Cond{"id": id}).One(&test); err != nil { + t.Fatal(err) + } + + if test.NullInt64Test.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullFloat64Test.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullBoolTest.Valid { + t.Fatalf(`Expecting invalid null.`) + } + + // In PostgreSQL, how we can tell if this is an invalid null? + + // if test.NullStringTest.Valid { + // t.Fatalf(`Expecting invalid null.`) + // } + + // Testing insertion of valid nulls. + test = test_t{ + NullStringTest: sql.NullString{"", true}, + NullInt64Test: sql.NullInt64{0, true}, + NullFloat64Test: sql.NullFloat64{0.0, true}, + NullBoolTest: sql.NullBool{false, true}, + } + if id, err = col.Append(test); err != nil { + t.Fatal(err) + } + + // Testing fetching of valid nulls. + if err = col.Find(db.Cond{"id": id}).One(&test); err != nil { + t.Fatal(err) + } + + if test.NullInt64Test.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullFloat64Test.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullBoolTest.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullStringTest.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + +} + // Attempts to count all rows in our newly defined set. func TestResultCount(t *testing.T) { var err error diff --git a/sqlite/layout.go b/sqlite/layout.go index 17f7c10a..bb750efe 100644 --- a/sqlite/layout.go +++ b/sqlite/layout.go @@ -121,4 +121,6 @@ const ( sqlDropTableLayout = ` DROP TABLE {{.Table}} ` + + sqlNull = `NULL` ) -- GitLab