diff --git a/db.go b/db.go index f3159cf1a5bd2c2957991c67f3cd96ae793c51dc..886408e2bfb7377a7feedac289fe80af8c2ccaf8 100644 --- a/db.go +++ b/db.go @@ -27,6 +27,7 @@ import ( "fmt" "github.com/gosexy/sugar" "strconv" + "regexp" "strings" "time" ) @@ -217,13 +218,47 @@ func (item Item) GetString(name string) string { func (item Item) GetDate(name string) time.Time { date := time.Date(0, time.January, 0, 0, 0, 0, 0, time.UTC) + switch item[name].(type) { case time.Time: date = item[name].(time.Time) + case string: + var matched bool + value := item[name].(string) + + matched, _ = regexp.MatchString(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`, value) + + if matched { + date, _ = time.Parse("2006-01-02 15:04:05", value) + } } return date } +func (item Item) GetDuration(name string) time.Duration { + duration, _ := time.ParseDuration("0h0m0s") + + switch item[name].(type) { + case time.Duration: + duration = item[name].(time.Duration) + case string: + var matched bool + var re *regexp.Regexp + value := item[name].(string) + + matched, _ = regexp.MatchString(`^\d{2}:\d{2}:\d{2}$`, value) + + if matched { + re, _ = regexp.Compile(`^(\d{2}):(\d{2}):(\d{2})$`) + all := re.FindAllStringSubmatch(value, -1) + + formatted := fmt.Sprintf("%sh%sm%ss", all[0][1], all[0][2], all[0][3]) + duration, _ = time.ParseDuration(formatted) + } + } + return duration +} + func (item Item) GetTuple(name string) sugar.Tuple { tuple := sugar.Tuple{} @@ -277,3 +312,36 @@ func (item Item) GetBool(name string) bool { return true } + +/* +func toInternal(val interface{}) interface{} { + + switch val.(type) { + case db.Id: + return bson.ObjectIdHex(string(val.(db.Id))) + case db.Item: + for k, _ := range val.(db.Item) { + val.(db.Item)[k] = toInternal(val.(db.Item)[k]) + } + } + + return val +} + +func toNative(val interface{}) interface{} { + + switch val.(type) { + case bson.M: + v2 := map[string]interface{}{} + for k, v := range val.(bson.M) { + v2[k] = toNative(v) + } + return v2 + case bson.ObjectId: + return db.Id(val.(bson.ObjectId).Hex()) + } + + return val + +} +*/ diff --git a/mysql/mysql.go b/mysql/mysql.go index aeb8b037dd920f7cb316f63179a1835945edcbbf..981696681102804c2bced591656c03248106a798 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -25,6 +25,7 @@ package mysql import ( _ "code.google.com/p/go-mysql-driver/mysql" + //_ "github.com/ziutek/mymysql/godrv" "database/sql" "fmt" "github.com/gosexy/db" @@ -33,8 +34,12 @@ import ( "regexp" "strconv" "strings" + "time" ) +const dateFormat = "2006-01-02 15:04:05.000000000" +const timeFormat = "%d:%02d:%02d.%09d" + type myQuery struct { Query []string SqlArgs []string @@ -145,8 +150,38 @@ func (t *MysqlTable) myFetchAll(rows sql.Rows) []db.Item { return items } +func toInternal(val interface{}) string { + + switch val.(type) { + case []byte: + return fmt.Sprintf("%s", string(val.([]byte))) + case time.Time: + return val.(time.Time).Format(dateFormat) + case time.Duration: + t := val.(time.Duration) + return fmt.Sprintf(timeFormat, int(t.Hours()), int(t.Minutes())%60, int(t.Seconds())%60, int(t.Nanoseconds())%1e9) + case bool: + if val.(bool) == true { + return "1" + } else { + return "0" + } + } + + return fmt.Sprintf("%v", val) +} + +func toNative(val interface{}) interface{} { + + switch val.(type) { + } + + return val + +} + // Executes a database/sql method. -func (my *MysqlDataSource) myExec(method string, terms ...interface{}) sql.Rows { +func (my *MysqlDataSource) myExec(method string, terms ...interface{}) (sql.Rows, error) { sn := reflect.ValueOf(my.session) fn := sn.MethodByName(method) @@ -169,10 +204,10 @@ func (my *MysqlDataSource) myExec(method string, terms ...interface{}) sql.Rows res := fn.Call(args) if res[1].IsNil() == false { - panic(res[1].Elem().Interface().(error)) + return sql.Rows{}, res[1].Elem().Interface().(error) } - return res[0].Elem().Interface().(sql.Rows) + return res[0].Elem().Interface().(sql.Rows), nil } // Represents a MySQL table. @@ -215,17 +250,18 @@ func (my *MysqlDataSource) Open() error { } conn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", my.config.User, my.config.Password, my.config.Host, my.config.Port, my.config.Database) + //conn := fmt.Sprintf("tcp:%s*%s/%s/%s", my.config.Host, my.config.Database, my.config.User, my.config.Password) my.session, err = sql.Open("mysql", conn) if err != nil { - return fmt.Errorf("Could not connect to %s", my.config.Host) + return err } return nil } -// Changes the active database. +// Changes the active database func (my *MysqlDataSource) Use(database string) error { my.config.Database = database my.session.Query(fmt.Sprintf("USE %s", database)) @@ -247,11 +283,14 @@ func (my *MysqlDataSource) Driver() interface{} { func (my *MysqlDataSource) Collections() []string { var collections []string var collection string - rows, _ := my.session.Query("SHOW TABLES") - for rows.Next() { - rows.Scan(&collection) - collections = append(collections, collection) + rows, err := my.session.Query("SHOW TABLES") + + if err == nil { + for rows.Next() { + rows.Scan(&collection) + collections = append(collections, collection) + } } return collections @@ -370,18 +409,18 @@ func (t *MysqlTable) marshal(where db.Cond) (string, []string) { } // Deletes all the rows in the table. -func (t *MysqlTable) Truncate() bool { +func (t *MysqlTable) Truncate() error { - t.parent.myExec( + _, err := t.parent.myExec( "Query", fmt.Sprintf("TRUNCATE TABLE %s", myTable(t.name)), ) - return false + return err } // Deletes all the rows in the table that match certain conditions. -func (t *MysqlTable) Remove(terms ...interface{}) bool { +func (t *MysqlTable) Remove(terms ...interface{}) error { conditions, cargs := t.compileConditions(terms) @@ -389,17 +428,17 @@ func (t *MysqlTable) Remove(terms ...interface{}) bool { conditions = "1 = 1" } - t.parent.myExec( + _, err := t.parent.myExec( "Query", fmt.Sprintf("DELETE FROM %s", myTable(t.name)), fmt.Sprintf("WHERE %s", conditions), cargs, ) - return true + return err } // Modifies all the rows in the table that match certain conditions. -func (t *MysqlTable) Update(terms ...interface{}) bool { +func (t *MysqlTable) Update(terms ...interface{}) error { var fields string var fargs db.SqlArgs @@ -418,13 +457,13 @@ func (t *MysqlTable) Update(terms ...interface{}) bool { conditions = "1 = 1" } - t.parent.myExec( + _, err := t.parent.myExec( "Query", fmt.Sprintf("UPDATE %s SET %s", myTable(t.name), fields), fargs, fmt.Sprintf("WHERE %s", conditions), cargs, ) - return true + return err } // Returns all the rows in the table that match certain conditions. @@ -465,7 +504,7 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []db.Item { conditions = "1 = 1" } - rows := t.parent.myExec( + rows, _ := t.parent.myExec( "Query", fmt.Sprintf("SELECT %s FROM %s", fields, myTable(t.name)), fmt.Sprintf("WHERE %s", conditions), args, @@ -595,7 +634,7 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []db.Item { } // Returns the number of rows in the current table that match certain conditions. -func (t *MysqlTable) Count(terms ...interface{}) int { +func (t *MysqlTable) Count(terms ...interface{}) (int, error) { terms = append(terms, db.Fields{"COUNT(1) AS _total"}) @@ -605,11 +644,11 @@ func (t *MysqlTable) Count(terms ...interface{}) int { response := result[0].Interface().([]db.Item) if len(response) > 0 { val, _ := strconv.Atoi(response[0]["_total"].(string)) - return val + return val, nil } } - return 0 + return 0, nil } // Returns the first row in the table that matches certain conditions. @@ -632,7 +671,7 @@ func (t *MysqlTable) Find(terms ...interface{}) db.Item { } // Inserts rows into the currently active table. -func (t *MysqlTable) Append(items ...interface{}) bool { +func (t *MysqlTable) Append(items ...interface{}) error { itop := len(items) @@ -645,10 +684,11 @@ func (t *MysqlTable) Append(items ...interface{}) bool { for field, value := range item.(db.Item) { fields = append(fields, field) - values = append(values, fmt.Sprintf("%v", value)) + //values = append(values, fmt.Sprintf("%v", value)) + values = append(values, toInternal(value)) } - t.parent.myExec("Query", + _, err := t.parent.myExec("Query", "INSERT INTO", myTable(t.name), myFields(fields), @@ -656,9 +696,10 @@ func (t *MysqlTable) Append(items ...interface{}) bool { myValues(values), ) + return err } - return true + return nil } // Returns a MySQL table structure by name. @@ -675,7 +716,7 @@ func (my *MysqlDataSource) Collection(name string) db.Collection { // Fetching table datatypes and mapping to internal gotypes. - rows := t.parent.myExec( + rows, _ := t.parent.myExec( "Query", "SHOW COLUMNS FROM", t.name, ) diff --git a/mysql/mysql_test.go b/mysql/mysql_test.go index d310a2130b0537824cc89e568e53dbf000ea6f9e..bb1b8a2190c2f4b9d6e6929b1ba5ba2e8e06b88d 100644 --- a/mysql/mysql_test.go +++ b/mysql/mysql_test.go @@ -4,16 +4,60 @@ import ( "database/sql" "fmt" "github.com/gosexy/db" + "github.com/gosexy/sugar" "github.com/kr/pretty" "math/rand" "testing" + "time" ) -const myHost = "10.0.0.11" +const myHost = "127.0.0.1" const myDatabase = "gotest" const myUser = "gouser" const myPassword = "gopass" +func getTestData() db.Item { + + _time, _ := time.ParseDuration("17h20m") + + data := db.Item{ + "_uint": uint(1), + "_uintptr": uintptr(1), + + "_uint8": uint8(1), + "_uint16": uint16(1), + "_uint32": uint32(1), + "_uint64": uint64(1), + + "_int": int(-1), + "_int8": int8(-1), + "_int16": int16(-1), + "_int32": int32(-1), + "_int64": int64(-1), + + "_float32": float32(1.0), + "_float64": float64(1.0), + + //"_complex64": complex64(1), + //"_complex128": complex128(1), + + "_byte": byte(1), + "_rune": rune(1), + + "_bool": bool(true), + "_string": string("abc"), + "_bytea": []byte{'a', 'b', 'c'}, + + //"_list": sugar.List{1, 2, 3}, + //"_map": sugar.Tuple{"a": 1, "b": 2, "c": 3}, + + "_date": time.Date(2012, 7, 28, 1, 2, 3, 0, time.UTC), + "_time": _time, + } + + return data +} + func TestMyTruncate(t *testing.T) { sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) @@ -30,7 +74,10 @@ func TestMyTruncate(t *testing.T) { for _, name := range collections { col := sess.Collection(name) col.Truncate() - if col.Count() != 0 { + + total, _ := col.Count() + + if total != 0 { t.Errorf("Could not truncate '%s'.", name) } } @@ -58,7 +105,9 @@ func TestMyAppend(t *testing.T) { col.Append(db.Item{"name": names[i]}) } - if col.Count() != len(names) { + total, _ := col.Count() + + if total != len(names) { t.Error("Could not append all items.") } @@ -244,3 +293,126 @@ func TestCustom(t *testing.T) { } } + +func TestDataTypes(t *testing.T) { + + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + + err := sess.Open() + + if err == nil { + defer sess.Close() + } + + col := sess.Collection("data_types") + + col.Truncate() + + data := getTestData() + + err = col.Append(data) + + if err != nil { + t.Errorf("Could not append test data.") + } + + // Getting and reinserting. + item := col.Find() + + err = col.Append(item) + + if err == nil { + t.Errorf("Expecting duplicated-key error.") + } + + delete(item, "id") + + err = col.Append(item) + + if err != nil { + t.Errorf("Could not append second element.") + } + + // Testing rows + items := col.FindAll() + + for i := 0; i < len(items); i++ { + + item := items[i] + + for key, _ := range item { + + switch key { + + // Signed integers. + case + "_int", + "_int8", + "_int16", + "_int32", + "_int64": + if item.GetInt(key) != int64(data["_int"].(int)) { + t.Errorf("Wrong datatype %v.", key) + } + + // Unsigned integers. + case + "_uint", + "_uintptr", + "_uint8", + "_uint16", + "_uint32", + "_uint64", + "_byte", + "_rune": + if item.GetInt(key) != int64(data["_uint"].(uint)) { + t.Errorf("Wrong datatype %v.", key) + } + + // Floating point. + case "_float32": + case "_float64": + if item.GetFloat(key) != data["_float64"].(float64) { + t.Errorf("Wrong datatype %v.", key) + } + + // Boolean + case "_bool": + if item.GetBool(key) != data["_bool"].(bool) { + t.Errorf("Wrong datatype %v.", key) + } + + // String + case "_string": + if item.GetString(key) != data["_string"].(string) { + t.Errorf("Wrong datatype %v.", key) + } + + // Map + case "_map": + if item.GetTuple(key)["a"] != data["_map"].(sugar.Tuple)["a"] { + t.Errorf("Wrong datatype %v.", key) + } + + // Array + case "_list": + if item.GetList(key)[0] != data["_list"].(sugar.List)[0] { + t.Errorf("Wrong datatype %v.", key) + } + + // Time + case "_time": + if item.GetDuration(key).String() != data["_time"].(time.Duration).String() { + t.Errorf("Wrong datatype %v.", key) + } + + // Date + case "_date": + if item.GetDate(key).Equal(data["_date"].(time.Time)) == false { + t.Errorf("Wrong datatype %v.", key) + } + } + } + } + +}