diff --git a/db.go b/db.go index 742c7827e6503cdf943d95910c6ef71d1cf8c13f..f3159cf1a5bd2c2957991c67f3cd96ae793c51dc 100644 --- a/db.go +++ b/db.go @@ -23,6 +23,14 @@ package db +import ( + "fmt" + "github.com/gosexy/sugar" + "strconv" + "strings" + "time" +) + // Handles conditions and operators in an expression. // // Examples: @@ -154,6 +162,8 @@ type Upsert map[string]interface{} // Rows from a result. type Item map[string]interface{} +type Id string + // Connection and authentication data. type DataSource struct { Host string @@ -179,18 +189,18 @@ type Database interface { // Collection methods. type Collection interface { - Append(...interface{}) bool + Append(...interface{}) error - Count(...interface{}) int + Count(...interface{}) (int, error) Find(...interface{}) Item FindAll(...interface{}) []Item - Update(...interface{}) bool + Update(...interface{}) error - Remove(...interface{}) bool + Remove(...interface{}) error - Truncate() bool + Truncate() error } // Specifies which fields to return in a query. @@ -200,3 +210,70 @@ type Fields []string type MultiFlag bool type SqlValues []string type SqlArgs []string + +func (item Item) GetString(name string) string { + return fmt.Sprintf("%v", item[name]) +} + +func (item Item) GetDate(name string) time.Time { + date := time.Date(0, time.January, 0, 0, 0, 0, 0, time.UTC) + switch item[name].(type) { + case time.Time: + date = item[name].(time.Time) + } + return date +} + +func (item Item) GetTuple(name string) sugar.Tuple { + tuple := sugar.Tuple{} + + switch item[name].(type) { + case map[string]interface{}: + for k, _ := range item[name].(map[string]interface{}) { + tuple[k] = item[name].(map[string]interface{})[k] + } + case sugar.Tuple: + tuple = item[name].(sugar.Tuple) + } + + return tuple +} + +func (item Item) GetList(name string) sugar.List { + list := sugar.List{} + + switch item[name].(type) { + case []interface{}: + list = make(sugar.List, len(item[name].([]interface{}))) + + for k, _ := range item[name].([]interface{}) { + list[k] = item[name].([]interface{})[k] + } + } + + return list +} + +func (item Item) GetInt(name string) int64 { + i, _ := strconv.ParseInt(fmt.Sprintf("%v", item[name]), 10, 64) + return i +} + +func (item Item) GetFloat(name string) float64 { + f, _ := strconv.ParseFloat(fmt.Sprintf("%v", item[name]), 64) + return f +} + +func (item Item) GetBool(name string) bool { + + if item[name] == nil { + return false + } else { + b := strings.ToLower(fmt.Sprintf("%v", item[name])) + if b == "" || b == "0" || b == "false" { + return false + } + } + + return true +} diff --git a/mongo/mongo.go b/mongo/mongo.go index 8408b2e5f10e4a59b380e04a52a083bb7c677504..9068dc6bc59176ff79b3ec5c0f0432a89bcba98a 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -29,6 +29,7 @@ import ( "github.com/gosexy/sugar" "labix.org/v2/mgo" "labix.org/v2/mgo/bson" + "log" "net/url" "reflect" "regexp" @@ -60,7 +61,15 @@ func (c *MongoDataSourceCollection) marshal(where db.Cond) map[string]interface{ if len(chunks) >= 2 { conds[chunks[0]] = map[string]interface{}{chunks[1]: val} } else { - conds[key] = val + conds[key] = toInternal(val) + /* + switch val.(type) { + case db.Id: + conds[key] = bson.ObjectIdHex(string(val.(db.Id))) + default: + conds[key] = val + } + */ } } @@ -69,18 +78,18 @@ func (c *MongoDataSourceCollection) marshal(where db.Cond) map[string]interface{ } // Deletes the whole collection. -func (c *MongoDataSourceCollection) Truncate() bool { +func (c *MongoDataSourceCollection) Truncate() error { err := c.collection.DropCollection() - if err == nil { - return false + if err != nil { + return err } - return true + return nil } // Inserts items into the collection. -func (c *MongoDataSourceCollection) Append(items ...interface{}) bool { +func (c *MongoDataSourceCollection) Append(items ...interface{}) error { parent := reflect.TypeOf(c.collection) method, _ := parent.MethodByName("Insert") @@ -89,17 +98,18 @@ func (c *MongoDataSourceCollection) Append(items ...interface{}) bool { args[0] = reflect.ValueOf(c.collection) itop := len(items) + for i := 0; i < itop; i++ { - args[i+1] = reflect.ValueOf(items[i]) + args[i+1] = reflect.ValueOf(toInternal(items[i])) } exec := method.Func.Call(args) if exec[0].Interface() != nil { - return false + return exec[0].Interface().(error) } - return true + return nil } // Compiles terms into conditions that mgo can understand. @@ -160,17 +170,17 @@ func (c *MongoDataSourceCollection) compileQuery(terms []interface{}) interface{ } // Removes all the items that match the provided conditions. -func (c *MongoDataSourceCollection) Remove(terms ...interface{}) bool { +func (c *MongoDataSourceCollection) Remove(terms ...interface{}) error { query := c.compileQuery(terms) - c.collection.RemoveAll(query) + _, err := c.collection.RemoveAll(query) - return true + return err } // Updates all the items that match the provided conditions. You can specify the modification type by using db.Set, db.Modify or db.Upsert. -func (c *MongoDataSourceCollection) Update(terms ...interface{}) bool { +func (c *MongoDataSourceCollection) Update(terms ...interface{}) error { var set interface{} var upsert interface{} @@ -197,22 +207,24 @@ func (c *MongoDataSourceCollection) Update(terms ...interface{}) bool { } } + var err error + if set != nil { - c.collection.UpdateAll(query, db.Item{"$set": set}) - return true + _, err = c.collection.UpdateAll(query, db.Item{"$set": set}) + return err } if modify != nil { - c.collection.UpdateAll(query, modify) - return true + _, err = c.collection.UpdateAll(query, modify) + return err } if upsert != nil { - c.collection.Upsert(query, upsert) - return true + _, err = c.collection.Upsert(query, upsert) + return err } - return false + return nil } // Calls a MongoDataSourceCollection function by string. @@ -235,19 +247,19 @@ func (c *MongoDataSourceCollection) invoke(fn string, terms []interface{}) []ref return exec } +func (c *MongoDataSourceCollection) Error(err error) { + log.Printf("%s: %s\n", c.collection.FullName, err) +} + // Returns the number of total items matching the provided conditions. -func (c *MongoDataSourceCollection) Count(terms ...interface{}) int { +func (c *MongoDataSourceCollection) Count(terms ...interface{}) (int, error) { q := c.invoke("BuildQuery", terms) p := q[0].Interface().(*mgo.Query) count, err := p.Count() - if err != nil { - panic(err) - } - - return count + return count, err } // Returns a document that matches all the provided conditions. db.Ordering of the terms doesn't matter but you must take in @@ -316,6 +328,37 @@ func (c *MongoDataSourceCollection) BuildQuery(terms ...interface{}) *mgo.Query return q } +func toInternal(val interface{}) interface{} { + + switch val.(type) { + case db.Id: + return bson.ObjectIdHex(string(val.(db.Id))) + case db.Item: + for k, _ := range val.(db.Item) { + val.(db.Item)[k] = toInternal(val.(db.Item)[k]) + } + } + + return val +} + +func toNative(val interface{}) interface{} { + + switch val.(type) { + case bson.M: + v2 := map[string]interface{}{} + for k, v := range val.(bson.M) { + v2[k] = toNative(v) + } + return v2 + case bson.ObjectId: + return db.Id(val.(bson.ObjectId).Hex()) + } + + return val + +} + // Returns all the results that match the provided conditions. See Find(). func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []db.Item { var items []db.Item @@ -397,7 +440,7 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []db.Item { // Default values. for key, val := range result[i].(bson.M) { - item[key] = val + item[key] = toNative(val) } // Querying relations diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index cd3226f988c8e7cee97ac414210fe161e4d77b75..ecd90b3a8928f950c53cefc286be89c70153f1af 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -3,17 +3,56 @@ package mongo import ( "fmt" "github.com/gosexy/db" + "github.com/gosexy/sugar" "github.com/kr/pretty" "math/rand" "testing" + "time" ) -const mgHost = "10.0.0.11" +const mgHost = "mongodbhost" const mgDatabase = "gotest" +func getTestData() db.Item { + data := db.Item{ + "_uint": uint(1), + "_uintptr": uintptr(1), + + "_uint8": uint8(1), + "_uint16": uint16(1), + "_uint32": uint32(1), + "_uint64": uint64(1), + + "_int": int(-1), + "_int8": int8(-1), + "_int16": int16(-1), + "_int32": int32(-1), + "_int64": int64(-1), + + "_float32": float32(1.0), + "_float64": float64(1.0), + + //"_complex64": complex64(1), + //"_complex128": complex128(1), + + "_byte": byte(1), + "_rune": rune(1), + + "_bool": bool(true), + "_string": string("abc"), + + "_list": sugar.List{1, 2, 3}, + "_map": sugar.Tuple{"a": 1, "b": 2, "c": 3}, + + "_date": time.Date(2012, 7, 28, 0, 0, 0, 0, time.UTC), + } + + return data +} + func TestMgOpen(t *testing.T) { - sess := Session(db.DataSource{Host: "0.0.0.0"}) + sess := Session(db.DataSource{Host: "1.1.1.1"}) err := sess.Open() defer sess.Close() @@ -74,7 +113,13 @@ func TestMgAppend(t *testing.T) { col.Append(db.Item{"name": names[i]}) } - if col.Count() != len(names) { + count, err := col.Count() + + if err != nil { + t.Error("Failed to count on collection.") + } + + if count != len(names) { t.Error("Could not append all items.") } @@ -102,9 +147,11 @@ func TestMgFind(t *testing.T) { } func TestMgDelete(t *testing.T) { + var err error + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) - err := sess.Open() + err = sess.Open() defer sess.Close() if err != nil { @@ -113,7 +160,11 @@ func TestMgDelete(t *testing.T) { col := sess.Collection("people") - col.Remove(db.Cond{"name": "Juan"}) + err = col.Remove(db.Cond{"name": "Juan"}) + + if err != nil { + t.Error("Failed to remove.") + } result := col.Find(db.Cond{"name": "Juan"}) @@ -123,9 +174,11 @@ func TestMgDelete(t *testing.T) { } func TestMgUpdate(t *testing.T) { + var err error + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) - err := sess.Open() + err = sess.Open() defer sess.Close() if err != nil { @@ -134,7 +187,11 @@ func TestMgUpdate(t *testing.T) { col := sess.Collection("people") - col.Update(db.Cond{"name": "José"}, db.Set{"name": "Joseph"}) + err = col.Update(db.Cond{"name": "José"}, db.Set{"name": "Joseph"}) + + if err != nil { + t.Error("Failed to update collection.") + } result := col.Find(db.Cond{"name": "Joseph"}) @@ -236,3 +293,122 @@ func TestMgRelation(t *testing.T) { fmt.Printf("%# v\n", pretty.Formatter(result)) } + +func TestDataTypes(t *testing.T) { + + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + + if err == nil { + defer sess.Close() + } + + col := sess.Collection("data_types") + + col.Truncate() + + data := getTestData() + + err = col.Append(data) + + if err != nil { + t.Errorf("Could not append test data.") + } + + // Getting and reinserting. + + item := col.Find() + + err = col.Append(item) + + if err == nil { + t.Errorf("Expecting duplicated-key error.") + } + + delete(item, "_id") + + err = col.Append(item) + + if err != nil { + t.Errorf("Could not append second element.") + } + + // Testing rows + + items := col.FindAll() + + for i := 0; i < len(items); i++ { + + item := items[i] + + for key, _ := range item { + + switch key { + + // Signed integers. + case + "_int", + "_int8", + "_int16", + "_int32", + "_int64": + if item.GetInt(key) != int64(data["_int"].(int)) { + t.Errorf("Wrong datatype %v.", key) + } + + // Unsigned integers. + case + "_uint", + "_uintptr", + "_uint8", + "_uint16", + "_uint32", + "_uint64", + "_byte", + "_rune": + if item.GetInt(key) != int64(data["_uint"].(uint)) { + t.Errorf("Wrong datatype %v.", key) + } + + // Floating point. + case "_float32": + case "_float64": + if item.GetFloat(key) != data["_float64"].(float64) { + t.Errorf("Wrong datatype %v.", key) + } + + // Boolean + case "_bool": + if item.GetBool(key) != data["_bool"].(bool) { + t.Errorf("Wrong datatype %v.", key) + } + + // String + case "_string": + if item.GetString(key) != data["_string"].(string) { + t.Errorf("Wrong datatype %v.", key) + } + + // Map + case "_map": + if item.GetTuple(key)["a"] != data["_map"].(sugar.Tuple)["a"] { + t.Errorf("Wrong datatype %v.", key) + } + + // Array + case "_list": + if item.GetList(key)[0] != data["_list"].(sugar.List)[0] { + t.Errorf("Wrong datatype %v.", key) + } + + // Date + case "_date": + if item.GetDate(key).Equal(data["_date"].(time.Time)) == false { + t.Errorf("Wrong datatype %v.", key) + } + } + } + } + +}