From 0d62d3d930143a7f754a401e78913f614a81b133 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <xiam@menteslibres.org>
Date: Sat, 1 Dec 2012 20:42:41 -0600
Subject: [PATCH] Adding ExistentCollection() and changing Collection().

---
 db.go               |   3 +-
 mongo/mongo.go      |  28 +++++++++---
 mongo/mongo_test.go | 107 +++++++++++++++++++++++---------------------
 3 files changed, 81 insertions(+), 57 deletions(-)

diff --git a/db.go b/db.go
index eab82db0..03b06073 100644
--- a/db.go
+++ b/db.go
@@ -181,7 +181,8 @@ type Database interface {
 	Open() error
 	Close() error
 
-	Collection(string) Collection
+	Collection(string) (Collection, error)
+	ExistentCollection(string) Collection
 	Collections() []string
 
 	Use(string) error
diff --git a/mongo/mongo.go b/mongo/mongo.go
index 9cfaa74d..4a12f345 100644
--- a/mongo/mongo.go
+++ b/mongo/mongo.go
@@ -43,7 +43,7 @@ func init() {
 
 // Session
 type MongoDataSource struct {
-	name string
+	name     string
 	config   db.DataSource
 	session  *mgo.Session
 	database *mgo.Database
@@ -51,7 +51,7 @@ type MongoDataSource struct {
 
 // Collection
 type MongoDataSourceCollection struct {
-	name string
+	name       string
 	parent     *MongoDataSource
 	collection *mgo.Collection
 }
@@ -433,7 +433,7 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []db.Item {
 	// This query is related to other collections.
 	if relate != nil {
 		for rname, rterms := range relate.(db.Relate) {
-			rcollection := c.parent.Collection(rname)
+			rcollection, _ := c.parent.Collection(rname)
 
 			ttop := len(rterms)
 			for t := ttop - 1; t >= 0; t-- {
@@ -450,7 +450,7 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []db.Item {
 
 	if relateAll != nil {
 		for rname, rterms := range relateAll.(db.RelateAll) {
-			rcollection := c.parent.Collection(rname)
+			rcollection, _ := c.parent.Collection(rname)
 
 			ttop := len(rterms)
 			for t := ttop - 1; t >= 0; t-- {
@@ -556,12 +556,28 @@ func (m *MongoDataSource) Exists() bool {
 }
 
 // Returns a collection from the current database.
-func (m *MongoDataSource) Collection(name string) db.Collection {
+func (m *MongoDataSource) Collection(name string) (db.Collection, error) {
+	var err error
+
 	c := &MongoDataSourceCollection{}
 	c.parent = m
 	c.name = name
 	c.collection = m.database.C(name)
-	return c
+
+	if c.Exists() == false {
+		err = fmt.Errorf("Collection %s does not exists.", name)
+	}
+
+	return c, err
+}
+
+// Returns a collection from the current database.
+func (self *MongoDataSource) ExistentCollection(name string) db.Collection {
+	col, err := self.Collection(name)
+	if err != nil {
+		panic(err.Error())
+	}
+	return col
 }
 
 // Returns the underlying driver (*mgo.Session).
diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go
index 12d6248e..eadcb4ca 100644
--- a/mongo/mongo_test.go
+++ b/mongo/mongo_test.go
@@ -104,10 +104,11 @@ func TestAppend(t *testing.T) {
 
 	defer sess.Close()
 
-	col := sess.Collection("people")
+	col, _ := sess.Collection("people")
 
 	if col.Exists() == true {
 		t.Errorf("Collection should not exists, yet.")
+		return
 	}
 
 	names := []string{"Juan", "José", "Pedro", "María", "Roberto", "Manuel", "Miguel"}
@@ -117,7 +118,8 @@ func TestAppend(t *testing.T) {
 	}
 
 	if col.Exists() == false {
-		t.Errorf("Collection should not exists.")
+		t.Errorf("Collection should exists.")
+		return
 	}
 
 	count, err := col.Count()
@@ -143,9 +145,9 @@ func TestFind(t *testing.T) {
 
 	defer sess.Close()
 
-	col := sess.Collection("people")
+	people, _ := sess.Collection("people")
 
-	result := col.Find(db.Cond{"name": "José"})
+	result := people.Find(db.Cond{"name": "José"})
 
 	if result["name"] != "José" {
 		t.Error("Could not find a recently appended item.")
@@ -164,15 +166,15 @@ func TestDelete(t *testing.T) {
 
 	defer sess.Close()
 
-	col := sess.Collection("people")
+	people := sess.ExistentCollection("people")
 
-	err = col.Remove(db.Cond{"name": "Juan"})
+	err = people.Remove(db.Cond{"name": "Juan"})
 
 	if err != nil {
 		t.Error("Failed to remove.")
 	}
 
-	result := col.Find(db.Cond{"name": "Juan"})
+	result := people.Find(db.Cond{"name": "Juan"})
 
 	if len(result) > 0 {
 		t.Error("Could not remove a recently appended item.")
@@ -189,15 +191,15 @@ func TestUpdate(t *testing.T) {
 
 	defer sess.Close()
 
-	col := sess.Collection("people")
+	people, _ := sess.Collection("people")
 
-	err = col.Update(db.Cond{"name": "José"}, db.Set{"name": "Joseph"})
+	err = people.Update(db.Cond{"name": "José"}, db.Set{"name": "Joseph"})
 
 	if err != nil {
 		t.Error("Failed to update collection.")
 	}
 
-	result := col.Find(db.Cond{"name": "Joseph"})
+	result := people.Find(db.Cond{"name": "Joseph"})
 
 	if len(result) == 0 {
 		t.Error("Could not update a recently appended item.")
@@ -205,7 +207,6 @@ func TestUpdate(t *testing.T) {
 }
 
 func TestPopulate(t *testing.T) {
-	var i int
 
 	sess, err := db.Open("mongo", db.DataSource{Host: host, Database: dbname})
 
@@ -216,43 +217,48 @@ func TestPopulate(t *testing.T) {
 
 	defer sess.Close()
 
-	places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"}
+	people, _ := sess.Collection("people")
+	places, _ := sess.Collection("places")
+	children, _ := sess.Collection("children")
+	visits, _ := sess.Collection("visits")
+
+	values := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"}
 
-	for i = 0; i < len(places); i++ {
-		sess.Collection("places").Append(db.Item{
+	for i, value := range values {
+		places.Append(db.Item{
 			"code_id": i,
-			"name":    places[i],
+			"name":    value,
 		})
 	}
 
-	people := sess.Collection("people").FindAll(
+	results := people.FindAll(
 		db.Fields{"id", "name"},
 		db.Sort{"name": "ASC", "id": -1},
 	)
 
-	for i = 0; i < len(people); i++ {
-		person := people[i]
+	for _, person := range results {
 
 		// Has 5 children.
+
 		for j := 0; j < 5; j++ {
-			sess.Collection("children").Append(db.Item{
+			children.Append(db.Item{
 				"name":      fmt.Sprintf("%s's child %d", person["name"], j+1),
 				"parent_id": person["_id"],
 			})
 		}
 
 		// Lives in
-		sess.Collection("people").Update(
+		people.Update(
 			db.Cond{"_id": person["_id"]},
-			db.Set{"place_code_id": int(rand.Float32() * float32(len(places)))},
+			db.Set{"place_code_id": int(rand.Float32() * float32(len(results)))},
 		)
 
 		// Has visited
 		for k := 0; k < 3; k++ {
-			place := sess.Collection("places").Find(db.Cond{
-				"code_id": int(rand.Float32() * float32(len(places))),
+			place := places.Find(db.Cond{
+				"code_id": int(rand.Float32() * float32(len(results))),
 			})
-			sess.Collection("visits").Append(db.Item{
+			visits.Append(db.Item{
 				"place_id":  place["_id"],
 				"person_id": person["_id"],
 			})
@@ -271,26 +277,29 @@ func TestRelation(t *testing.T) {
 
 	defer sess.Close()
 
-	col := sess.Collection("people")
+	people, _ := sess.Collection("people")
+	places, _ := sess.Collection("places")
+	children, _ := sess.Collection("children")
+	visits, _ := sess.Collection("visits")
 
-	result := col.FindAll(
+	result := people.FindAll(
 		db.Relate{
 			"lives_in": db.On{
-				sess.Collection("places"),
+				places,
 				db.Cond{"code_id": "{place_code_id}"},
 			},
 		},
 		db.RelateAll{
 			"has_children": db.On{
-				sess.Collection("children"),
+				children,
 				db.Cond{"parent_id": "{_id}"},
 			},
 			"has_visited": db.On{
-				sess.Collection("visits"),
+				visits,
 				db.Cond{"person_id": "{_id}"},
 				db.Relate{
 					"place": db.On{
-						sess.Collection("places"),
+						places,
 						db.Cond{"_id": "{place_id}"},
 					},
 				},
@@ -312,19 +321,19 @@ func TestDataTypes(t *testing.T) {
 
 	defer sess.Close()
 
-	col := sess.Collection("data_types")
+	dataTypes, _ := sess.Collection("data_types")
 
-	col.Truncate()
+	dataTypes.Truncate()
 
-	data := testItem()
+	testData := testItem()
 
-	ids, err := col.Append(data)
+	ids, err := dataTypes.Append(testData)
 
 	if err != nil {
 		t.Errorf("Could not append test data.")
 	}
 
-	found, _ := col.Count(db.Cond{"_id": db.Id(ids[0])})
+	found, _ := dataTypes.Count(db.Cond{"_id": db.Id(ids[0])})
 
 	if found == 0 {
 		t.Errorf("Cannot find recently inserted item (by ID).")
@@ -332,9 +341,9 @@ func TestDataTypes(t *testing.T) {
 
 	// Getting and reinserting.
 
-	item := col.Find()
+	item := dataTypes.Find()
 
-	_, err = col.Append(item)
+	_, err = dataTypes.Append(item)
 
 	if err == nil {
 		t.Errorf("Expecting duplicated-key error.")
@@ -342,7 +351,7 @@ func TestDataTypes(t *testing.T) {
 
 	delete(item, "_id")
 
-	_, err = col.Append(item)
+	_, err = dataTypes.Append(item)
 
 	if err != nil {
 		t.Errorf("Could not append second element.")
@@ -350,11 +359,9 @@ func TestDataTypes(t *testing.T) {
 
 	// Testing rows
 
-	items := col.FindAll()
-
-	for i := 0; i < len(items); i++ {
+	results := dataTypes.FindAll()
 
-		item := items[i]
+	for _, item := range results {
 
 		for key, _ := range item {
 
@@ -367,7 +374,7 @@ func TestDataTypes(t *testing.T) {
 				"_int16",
 				"_int32",
 				"_int64":
-				if item.GetInt(key) != int64(data["_int"].(int)) {
+				if item.GetInt(key) != int64(testData["_int"].(int)) {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
@@ -381,44 +388,44 @@ func TestDataTypes(t *testing.T) {
 				"_uint64",
 				"_byte",
 				"_rune":
-				if item.GetInt(key) != int64(data["_uint"].(uint)) {
+				if item.GetInt(key) != int64(testData["_uint"].(uint)) {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
 			// Floating point.
 			case "_float32":
 			case "_float64":
-				if item.GetFloat(key) != data["_float64"].(float64) {
+				if item.GetFloat(key) != testData["_float64"].(float64) {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
 			// Boolean
 			case "_bool":
-				if item.GetBool(key) != data["_bool"].(bool) {
+				if item.GetBool(key) != testData["_bool"].(bool) {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
 			// String
 			case "_string":
-				if item.GetString(key) != data["_string"].(string) {
+				if item.GetString(key) != testData["_string"].(string) {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
 			// Map
 			case "_map":
-				if item.GetTuple(key)["a"] != data["_map"].(sugar.Tuple)["a"] {
+				if item.GetTuple(key)["a"] != testData["_map"].(sugar.Tuple)["a"] {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
 			// Array
 			case "_list":
-				if item.GetList(key)[0] != data["_list"].(sugar.List)[0] {
+				if item.GetList(key)[0] != testData["_list"].(sugar.List)[0] {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 
 			// Date
 			case "_date":
-				if item.GetDate(key).Equal(data["_date"].(time.Time)) == false {
+				if item.GetDate(key).Equal(testData["_date"].(time.Time)) == false {
 					t.Errorf("Wrong datatype %v.", key)
 				}
 			}
-- 
GitLab