diff --git a/mysql/_dumps/structs.sql b/mysql/_dumps/structs.sql index a6046cbe0313439c64e9a05577a7f5f13e65ac2e..42240f4b4c6a875cb522d94c45d40fd6cedf70a8 100644 --- a/mysql/_dumps/structs.sql +++ b/mysql/_dumps/structs.sql @@ -33,16 +33,16 @@ DROP TABLE IF EXISTS data_types; CREATE TABLE data_types ( id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), - _uint INT(10) UNSIGNED NOT NULL DEFAULT 0, - _uint8 INT(10) UNSIGNED NOT NULL DEFAULT 0, - _uint16 INT(10) UNSIGNED NOT NULL DEFAULT 0, - _uint32 INT(10) UNSIGNED NOT NULL DEFAULT 0, - _uint64 INT(10) UNSIGNED NOT NULL DEFAULT 0, - _int INT(10) NOT NULL DEFAULT 0, - _int8 INT(10) NOT NULL DEFAULT 0, - _int16 INT(10) NOT NULL DEFAULT 0, - _int32 INT(10) NOT NULL DEFAULT 0, - _int64 INT(10) NOT NULL DEFAULT 0, + _uint INT(10) UNSIGNED DEFAULT 0, + _uint8 INT(10) UNSIGNED DEFAULT 0, + _uint16 INT(10) UNSIGNED DEFAULT 0, + _uint32 INT(10) UNSIGNED DEFAULT 0, + _uint64 INT(10) UNSIGNED DEFAULT 0, + _int INT(10) DEFAULT 0, + _int8 INT(10) DEFAULT 0, + _int16 INT(10) DEFAULT 0, + _int32 INT(10) DEFAULT 0, + _int64 INT(10) DEFAULT 0, _float32 DECIMAL(10,6), _float64 DECIMAL(10,6), _bool TINYINT(1), diff --git a/mysql/collection.go b/mysql/collection.go index 0091d2bce2a93cdf338a171ca61087ca3715e7ff..429f8c2f18747f0826a0ab638d22acbee78b23f7 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -27,6 +27,7 @@ import ( "strings" "time" + "database/sql" "menteslibres.net/gosexy/to" "upper.io/db" "upper.io/db/util/sqlgen" @@ -214,23 +215,34 @@ func (self *Table) Truncate() error { // Appends an item (map or struct) into the collection. func (self *Table) Append(item interface{}) (interface{}, error) { - - cols, vals, err := self.FieldValues(item, toInternal) - var columns sqlgen.Columns var values sqlgen.Values + var arguments []interface{} + var id int64 - for _, col := range cols { - columns = append(columns, sqlgen.Column{col}) + cols, vals, err := self.FieldValues(item, toInternal) + + if err != nil { + return nil, err } - for i := 0; i < len(vals); i++ { - values = append(values, sqlPlaceholder) + columns = make(sqlgen.Columns, 0, len(cols)) + for i := range cols { + columns = append(columns, sqlgen.Column{cols[i]}) } - // Error ocurred, stop appending. - if err != nil { - return nil, err + arguments = make([]interface{}, 0, len(vals)) + values = make(sqlgen.Values, 0, len(vals)) + for i := range vals { + switch v := vals[i].(type) { + case sqlgen.Value: + // Adding value. + values = append(values, v) + default: + // Adding both value and placeholder. + values = append(values, sqlPlaceholder) + arguments = append(arguments, v) + } } res, err := self.source.doExec(sqlgen.Statement{ @@ -238,13 +250,12 @@ func (self *Table) Append(item interface{}) (interface{}, error) { Table: sqlgen.Table{self.tableN(0)}, Columns: columns, Values: values, - }, vals...) + }, arguments...) if err != nil { return nil, err } - var id int64 id, _ = res.LastInsertId() return id, nil @@ -271,6 +282,42 @@ func toInternal(val interface{}) interface{} { return t.Format(DateFormat) case time.Duration: return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond) + case sql.NullBool: + if t.Valid { + if t.Bool { + return toInternal(t.Bool) + } else { + return false + } + } else { + return sqlgen.Value{sqlgen.Raw{mysqlNull}} + } + case sql.NullFloat64: + if t.Valid { + if t.Float64 != 0.0 { + return toInternal(t.Float64) + } else { + return float64(0) + } + } else { + return sqlgen.Value{sqlgen.Raw{mysqlNull}} + } + case sql.NullInt64: + if t.Valid { + if t.Int64 != 0 { + return toInternal(t.Int64) + } else { + return 0 + } + } else { + return sqlgen.Value{sqlgen.Raw{mysqlNull}} + } + case sql.NullString: + if t.Valid { + return toInternal(t.String) + } else { + return sqlgen.Value{sqlgen.Raw{mysqlNull}} + } case bool: if t == true { return `1` diff --git a/mysql/database_test.go b/mysql/database_test.go index 3a22d56be60fc9fb7691243cab307668f23614ee..b2baadb3cd5bcf6e6d3bffc31d94cc5a9ba758bb 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -227,6 +227,97 @@ func TestAppend(t *testing.T) { } +// Attempts to test nullable fields. +func TestNullableFields(t *testing.T) { + var err error + var sess db.Database + var col db.Collection + var id interface{} + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + type test_t struct { + Id int64 `db:"id,omitempty"` + NullStringTest sql.NullString `db:"_string"` + NullInt64Test sql.NullInt64 `db:"_int64"` + NullFloat64Test sql.NullFloat64 `db:"_float64"` + NullBoolTest sql.NullBool `db:"_bool"` + } + + var test test_t + + if col, err = sess.Collection(`data_types`); err != nil { + t.Fatal(err) + } + + if err = col.Truncate(); err != nil { + t.Fatal(err) + } + + // Testing insertion of invalid nulls. + test = test_t{ + NullStringTest: sql.NullString{"", false}, + NullInt64Test: sql.NullInt64{0, false}, + NullFloat64Test: sql.NullFloat64{0.0, false}, + NullBoolTest: sql.NullBool{false, false}, + } + if id, err = col.Append(test_t{}); err != nil { + t.Fatal(err) + } + + // Testing fetching of invalid nulls. + if err = col.Find(db.Cond{"id": id}).One(&test); err != nil { + t.Fatal(err) + } + + if test.NullInt64Test.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullFloat64Test.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullBoolTest.Valid { + t.Fatalf(`Expecting invalid null.`) + } + if test.NullStringTest.Valid { + t.Fatalf(`Expecting invalid null.`) + } + + // Testing insertion of valid nulls. + test = test_t{ + NullStringTest: sql.NullString{"", true}, + NullInt64Test: sql.NullInt64{0, true}, + NullFloat64Test: sql.NullFloat64{0.0, true}, + NullBoolTest: sql.NullBool{false, true}, + } + if id, err = col.Append(test); err != nil { + t.Fatal(err) + } + + // Testing fetching of valid nulls. + if err = col.Find(db.Cond{"id": id}).One(&test); err != nil { + t.Fatal(err) + } + + if test.NullInt64Test.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullFloat64Test.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullBoolTest.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + if test.NullStringTest.Valid == false { + t.Fatalf(`Expecting valid value.`) + } + +} + // Attempts to count all rows in our newly defined set. func TestResultCount(t *testing.T) { var err error diff --git a/mysql/layout.go b/mysql/layout.go index 98568de32765984fa693c692b4d4a6ed6a948bfb..21ac5e06322ea29ea0a10d34bda63c1113ce5b02 100644 --- a/mysql/layout.go +++ b/mysql/layout.go @@ -121,4 +121,6 @@ const ( mysqlDropTableLayout = ` DROP TABLE {{.Table}} ` + + mysqlNull = `NULL` )