diff --git a/db/README.md b/db/README.md index d35a6400df25d3e5d6e7f4897dc6dbfb77b3d16c..06e98aa011f890cc4cbde2b3f0c778f350b2d1b2 100644 --- a/db/README.md +++ b/db/README.md @@ -2,6 +2,8 @@ This package is a wrapper of [mgo](http://launchpad.net/mgo), [database/sql](http://golang.org/pkg/database/sql) and some of its database drivers friends, the goal of this abstraction is to provide a common, simplified, consistent layer for working with different databases using Go. +**IMPORTANT:** Recent changes have rendered this documentation inaccurate, please wait until I can review and update it. + ## Installation Please read docs on the [gosexy](https://github.com/xiam/gosexy) package before rushing to install ``gosexy/db`` diff --git a/db/db.go b/db/db.go index 0dc5ca5dd04fcee36b55cea41af24d5c10fa980a..e5b85e356c9d00181555087df62320562923f396 100644 --- a/db/db.go +++ b/db/db.go @@ -197,6 +197,6 @@ type Collection interface { type Fields []string // Specifies single or multiple requests in FindAll() expressions. -type multiFlag bool -type sqlValues []string -type sqlArgs []string +type MultiFlag bool +type SqlValues []string +type SqlArgs []string diff --git a/db/examples/mongo.go b/db/examples/mongo.go deleted file mode 100644 index 4e2ca6b6edb416c46a867f49585b48a676a7a946..0000000000000000000000000000000000000000 --- a/db/examples/mongo.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "fmt" - "github.com/kr/pretty" - . "github.com/xiam/gosexy/db" -) - -func main() { - db := MongoSession(DataSource{Host: "10.0.0.11", Database: "gotest"}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.FindAll( - Relate{ - "lives_in": On{ - db.Collection("places"), - Where{"code_id": "{place_code_id}"}, - }, - }, - RelateAll{ - "has_children": On{ - db.Collection("children"), - Where{"parent_id": "{_id}"}, - }, - "has_visited": On{ - db.Collection("visits"), - Where{"person_id": "{_id}"}, - Relate{ - "place": On{ - db.Collection("places"), - Where{"_id": "{place_id}"}, - }, - }, - }, - }, - ) - - fmt.Printf("%# v\n", pretty.Formatter(result)) -} diff --git a/db/mongo/examples/mongo.go b/db/mongo/examples/mongo.go new file mode 100644 index 0000000000000000000000000000000000000000..ea069b686a409ce79e29126a891cf761eb4e2b0f --- /dev/null +++ b/db/mongo/examples/mongo.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "github.com/kr/pretty" + "github.com/xiam/gosexy/db" + "github.com/xiam/gosexy/db/mongo" +) + +func main() { + sess := mongo.Session(db.DataSource{Host: "10.0.0.11", Database: "gotest"}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.FindAll( + db.Relate{ + "lives_in": db.On{ + sess.Collection("places"), + db.Where{"code_id": "{place_code_id}"}, + }, + }, + db.RelateAll{ + "has_children": db.On{ + sess.Collection("children"), + db.Where{"parent_id": "{_id}"}, + }, + "has_visited": db.On{ + sess.Collection("visits"), + db.Where{"person_id": "{_id}"}, + db.Relate{ + "place": db.On{ + sess.Collection("places"), + db.Where{"_id": "{place_id}"}, + }, + }, + }, + }, + ) + + fmt.Printf("%# v\n", pretty.Formatter(result)) +} diff --git a/db/mongo.go b/db/mongo/mongo.go similarity index 80% rename from db/mongo.go rename to db/mongo/mongo.go index d18da2c7ee83676a4653b3b59ccc01ec96cb2eee..9d65a259f258d0eccb80193fd4a05b596e044825 100644 --- a/db/mongo.go +++ b/db/mongo/mongo.go @@ -21,11 +21,12 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package db +package mongo import ( "fmt" - . "github.com/xiam/gosexy" + "github.com/xiam/gosexy" + "github.com/xiam/gosexy/db" "labix.org/v2/mgo" "labix.org/v2/mgo/bson" "net/url" @@ -37,7 +38,7 @@ import ( // MongoDataSource session. type MongoDataSource struct { - config DataSource + config db.DataSource session *mgo.Session database *mgo.Database } @@ -48,8 +49,8 @@ type MongoDataSourceCollection struct { collection *mgo.Collection } -// Converts Where keytypes into something that mgo can understand. -func (c *MongoDataSourceCollection) marshal(where Where) map[string]interface{} { +// Converts db.Where keytypes into something that mgo can understand. +func (c *MongoDataSourceCollection) marshal(where db.Where) map[string]interface{} { conds := make(map[string]interface{}) for key, val := range where { @@ -118,29 +119,29 @@ func (c *MongoDataSourceCollection) compileConditions(term interface{}) interfac return values } } - case Or: + case db.Or: { values := []interface{}{} - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - values = append(values, c.compileConditions(term.(Or)[i])) + values = append(values, c.compileConditions(term.(db.Or)[i])) } condition := map[string]interface{}{"$or": values} return condition } - case And: + case db.And: { values := []interface{}{} - itop := len(term.(And)) + itop := len(term.(db.And)) for i := 0; i < itop; i++ { - values = append(values, c.compileConditions(term.(And)[i])) + values = append(values, c.compileConditions(term.(db.And)[i])) } condition := map[string]interface{}{"$and": values} return condition } - case Where: + case db.Where: { - return c.marshal(term.(Where)) + return c.marshal(term.(db.Where)) } } return nil @@ -176,7 +177,7 @@ func (c *MongoDataSourceCollection) Remove(terms ...interface{}) bool { return true } -// Updates all the items that match the provided conditions. You can specify the modification type by using Set, Modify or Upsert. +// Updates all the items that match the provided conditions. You can specify the modification type by using db.Set, db.Modify or db.Upsert. func (c *MongoDataSourceCollection) Update(terms ...interface{}) bool { var set interface{} @@ -195,23 +196,23 @@ func (c *MongoDataSourceCollection) Update(terms ...interface{}) bool { term := terms[i] switch term.(type) { - case Set: + case db.Set: { - set = term.(Set) + set = term.(db.Set) } - case Upsert: + case db.Upsert: { - upsert = term.(Upsert) + upsert = term.(db.Upsert) } - case Modify: + case db.Modify: { - modify = term.(Modify) + modify = term.(db.Modify) } } } if set != nil { - c.collection.UpdateAll(query, Item{"$set": set}) + c.collection.UpdateAll(query, db.Item{"$set": set}) return true } @@ -263,18 +264,18 @@ func (c *MongoDataSourceCollection) Count(terms ...interface{}) int { return count } -// Returns a document that matches all the provided conditions. Ordering of the terms doesn't matter but you must take in +// Returns a document that matches all the provided conditions. db.Ordering of the terms doesn't matter but you must take in // account that conditions are generally evaluated from left to right (or from top to bottom). -func (c *MongoDataSourceCollection) Find(terms ...interface{}) Item { +func (c *MongoDataSourceCollection) Find(terms ...interface{}) db.Item { - var item Item + var item db.Item - terms = append(terms, Limit(1)) + terms = append(terms, db.Limit(1)) result := c.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { item = response[0] } @@ -300,17 +301,17 @@ func (c *MongoDataSourceCollection) BuildQuery(terms ...interface{}) *mgo.Query term := terms[i] switch term.(type) { - case Limit: + case db.Limit: { - limit = int(term.(Limit)) + limit = int(term.(db.Limit)) } - case Offset: + case db.Offset: { - offset = int(term.(Offset)) + offset = int(term.(db.Offset)) } - case Sort: + case db.Sort: { - sort = term.(Sort) + sort = term.(db.Sort) } } } @@ -336,8 +337,8 @@ func (c *MongoDataSourceCollection) BuildQuery(terms ...interface{}) *mgo.Query } // Returns all the results that match the provided conditions. See Find(). -func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { - var items []Item +func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []db.Item { + var items []db.Item var result []interface{} var relate interface{} @@ -352,13 +353,13 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { term := terms[i] switch term.(type) { - case Relate: + case db.Relate: { - relate = term.(Relate) + relate = term.(db.Relate) } - case RelateAll: + case db.RelateAll: { - relateAll = term.(RelateAll) + relateAll = term.(db.RelateAll) } } } @@ -370,44 +371,44 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { p.All(&result) - var relations []Tuple + var relations []gosexy.Tuple // This query is related to other collections. if relate != nil { - for rname, rterms := range relate.(Relate) { + for rname, rterms := range relate.(db.Relate) { rcollection := c.parent.Collection(rname) ttop := len(rterms) for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } - relations = append(relations, Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) } } if relateAll != nil { - for rname, rterms := range relateAll.(RelateAll) { + for rname, rterms := range relateAll.(db.RelateAll) { rcollection := c.parent.Collection(rname) ttop := len(rterms) for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } - relations = append(relations, Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) } } @@ -416,11 +417,11 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { jtop := len(relations) itop = len(result) - items = make([]Item, itop) + items = make([]db.Item, itop) for i := 0; i < itop; i++ { - item := Item{} + item := db.Item{} // Default values. for key, val := range result[i].(bson.M) { @@ -434,18 +435,18 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { terms := []interface{}{} - ktop := len(relation["terms"].(On)) + ktop := len(relation["terms"].(db.On)) for k := 0; k < ktop; k++ { //term = tcopy[k] - term = relation["terms"].(On)[k] + term = relation["terms"].(db.On)[k] switch term.(type) { - // Just waiting for Where statements. - case Where: + // Just waiting for db.Where statements. + case db.Where: { - for wkey, wval := range term.(Where) { + for wkey, wval := range term.(db.Where) { //if reflect.TypeOf(wval).Kind() == reflect.String { // does not always work. if reflect.TypeOf(wval).Name() == "string" { // Matching dynamic values. @@ -453,7 +454,7 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { if matched { // Replacing dynamic values. kname := strings.Trim(wval.(string), "{}") - term = Where{wkey: item[kname]} + term = db.Where{wkey: item[kname]} } } } @@ -465,10 +466,10 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { // Executing external query. if relation["all"] == true { value := relation["collection"].(*MongoDataSourceCollection).invoke("FindAll", terms) - item[relation["name"].(string)] = value[0].Interface().([]Item) + item[relation["name"].(string)] = value[0].Interface().([]db.Item) } else { value := relation["collection"].(*MongoDataSourceCollection).invoke("Find", terms) - item[relation["name"].(string)] = value[0].Interface().(Item) + item[relation["name"].(string)] = value[0].Interface().(db.Item) } } @@ -481,21 +482,21 @@ func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { } // Returns a new MongoDataSource object. -func MongoSession(config DataSource) Database { +func Session(config db.DataSource) db.Database { m := &MongoDataSource{} m.config = config return m } -// Switches the current session database to the provided name. See MongoSession(). +// Switches the current session database to the provided name. See Session(). func (m *MongoDataSource) Use(database string) error { m.config.Database = database m.database = m.session.DB(m.config.Database) return nil } -// Returns a Collection from the currently active database given the name. See MongoSession(). -func (m *MongoDataSource) Collection(name string) Collection { +// Returns a Collection from the currently active database given the name. See Session(). +func (m *MongoDataSource) Collection(name string) db.Collection { c := &MongoDataSourceCollection{} c.parent = m c.collection = m.database.C(name) @@ -506,7 +507,7 @@ func (m *MongoDataSource) Driver() interface{} { return m.session } -// Connects to the previously specified datasource. See MongoSession(). +// Connects to the previously specified datasource. See Session(). func (m *MongoDataSource) Open() error { var err error diff --git a/db/mongo/mongo_test.go b/db/mongo/mongo_test.go new file mode 100644 index 0000000000000000000000000000000000000000..42e797d2790e6b9630b21f19430d1dc90a73c4ef --- /dev/null +++ b/db/mongo/mongo_test.go @@ -0,0 +1,238 @@ +package mongo + +import ( + "fmt" + "github.com/kr/pretty" + "github.com/xiam/gosexy/db" + "math/rand" + "testing" +) + +const mgHost = "10.0.0.11" +const mgDatabase = "gotest" + +func TestMgOpen(t *testing.T) { + + sess := Session(db.DataSource{Host: "0.0.0.0"}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + t.Logf("Got %t, this was intended.", err) + return + } + + t.Error("Are you serious?") +} + +func TestMgAuthFail(t *testing.T) { + + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase, User: "unknown", Password: "fail"}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + t.Logf("Got %t, this was intended.", err) + return + } + + t.Error("Are you serious?") +} + +func TestMgDrop(t *testing.T) { + + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + sess.Drop() +} + +func TestMgAppend(t *testing.T) { + + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} + + for i := 0; i < len(names); i++ { + col.Append(db.Item{"name": names[i]}) + } + + if col.Count() != len(names) { + t.Error("Could not append all items.") + } + +} + +func TestMgFind(t *testing.T) { + + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.Find(db.Where{"name": "José"}) + + if result["name"] != "José" { + t.Error("Could not find a recently appended item.") + } + +} + +func TestMgDelete(t *testing.T) { + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Remove(db.Where{"name": "Juan"}) + + result := col.Find(db.Where{"name": "Juan"}) + + if len(result) > 0 { + t.Error("Could not remove a recently appended item.") + } +} + +func TestMgUpdate(t *testing.T) { + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Update(db.Where{"name": "José"}, db.Set{"name": "Joseph"}) + + result := col.Find(db.Where{"name": "Joseph"}) + + if len(result) == 0 { + t.Error("Could not update a recently appended item.") + } +} + +func TestMgPopulate(t *testing.T) { + var i int + + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} + + for i = 0; i < len(places); i++ { + sess.Collection("places").Append(db.Item{ + "code_id": i, + "name": places[i], + }) + } + + people := sess.Collection("people").FindAll() + + for i = 0; i < len(people); i++ { + person := people[i] + + // Has 5 children. + for j := 0; j < 5; j++ { + sess.Collection("children").Append(db.Item{ + "name": fmt.Sprintf("%s's child %d", person["name"], j+1), + "parent_id": person["_id"], + }) + } + + // Lives in + sess.Collection("people").Update( + db.Where{"_id": person["_id"]}, + db.Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, + ) + + // Has visited + for k := 0; k < 3; k++ { + place := sess.Collection("places").Find(db.Where{ + "code_id": int(rand.Float32() * float32(len(places))), + }) + sess.Collection("visits").Append(db.Item{ + "place_id": place["_id"], + "person_id": person["_id"], + }) + } + } + +} + +func TestMgRelation(t *testing.T) { + sess := Session(db.DataSource{Host: mgHost, Database: mgDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.FindAll( + db.Relate{ + "lives_in": db.On{ + sess.Collection("places"), + db.Where{"code_id": "{place_code_id}"}, + }, + }, + db.RelateAll{ + "has_children": db.On{ + sess.Collection("children"), + db.Where{"parent_id": "{_id}"}, + }, + "has_visited": db.On{ + sess.Collection("visits"), + db.Where{"person_id": "{_id}"}, + db.Relate{ + "place": db.On{ + sess.Collection("places"), + db.Where{"_id": "{place_id}"}, + }, + }, + }, + }, + ) + + fmt.Printf("%# v\n", pretty.Formatter(result)) +} diff --git a/db/mongo_test.go b/db/mongo_test.go deleted file mode 100644 index 601d15fa99556120621c852c5df247e67a329e3e..0000000000000000000000000000000000000000 --- a/db/mongo_test.go +++ /dev/null @@ -1,237 +0,0 @@ -package db - -import ( - "fmt" - "github.com/kr/pretty" - "math/rand" - "testing" -) - -const mgHost = "10.0.0.11" -const mgDatabase = "gotest" - -func TestMgOpen(t *testing.T) { - - db := MongoSession(DataSource{Host: "0.0.0.0"}) - - err := db.Open() - defer db.Close() - - if err != nil { - t.Logf("Got %t, this was intended.", err) - return - } - - t.Error("Are you serious?") -} - -func TestMgAuthFail(t *testing.T) { - - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase, User: "unknown", Password: "fail"}) - - err := db.Open() - defer db.Close() - - if err != nil { - t.Logf("Got %t, this was intended.", err) - return - } - - t.Error("Are you serious?") -} - -func TestMgDrop(t *testing.T) { - - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - db.Drop() -} - -func TestMgAppend(t *testing.T) { - - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} - - for i := 0; i < len(names); i++ { - col.Append(Item{"name": names[i]}) - } - - if col.Count() != len(names) { - t.Error("Could not append all items.") - } - -} - -func TestMgFind(t *testing.T) { - - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.Find(Where{"name": "José"}) - - if result["name"] != "José" { - t.Error("Could not find a recently appended item.") - } - -} - -func TestMgDelete(t *testing.T) { - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Remove(Where{"name": "Juan"}) - - result := col.Find(Where{"name": "Juan"}) - - if len(result) > 0 { - t.Error("Could not remove a recently appended item.") - } -} - -func TestMgUpdate(t *testing.T) { - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Update(Where{"name": "José"}, Set{"name": "Joseph"}) - - result := col.Find(Where{"name": "Joseph"}) - - if len(result) == 0 { - t.Error("Could not update a recently appended item.") - } -} - -func TestMgPopulate(t *testing.T) { - var i int - - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} - - for i = 0; i < len(places); i++ { - db.Collection("places").Append(Item{ - "code_id": i, - "name": places[i], - }) - } - - people := db.Collection("people").FindAll() - - for i = 0; i < len(people); i++ { - person := people[i] - - // Has 5 children. - for j := 0; j < 5; j++ { - db.Collection("children").Append(Item{ - "name": fmt.Sprintf("%s's child %d", person["name"], j+1), - "parent_id": person["_id"], - }) - } - - // Lives in - db.Collection("people").Update( - Where{"_id": person["_id"]}, - Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, - ) - - // Has visited - for k := 0; k < 3; k++ { - place := db.Collection("places").Find(Where{ - "code_id": int(rand.Float32() * float32(len(places))), - }) - db.Collection("visits").Append(Item{ - "place_id": place["_id"], - "person_id": person["_id"], - }) - } - } - -} - -func TestMgRelation(t *testing.T) { - db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.FindAll( - Relate{ - "lives_in": On{ - db.Collection("places"), - Where{"code_id": "{place_code_id}"}, - }, - }, - RelateAll{ - "has_children": On{ - db.Collection("children"), - Where{"parent_id": "{_id}"}, - }, - "has_visited": On{ - db.Collection("visits"), - Where{"person_id": "{_id}"}, - Relate{ - "place": On{ - db.Collection("places"), - Where{"_id": "{place_id}"}, - }, - }, - }, - }, - ) - - fmt.Printf("%# v\n", pretty.Formatter(result)) -} diff --git a/db/mysql.go b/db/mysql/mysql.go similarity index 80% rename from db/mysql.go rename to db/mysql/mysql.go index dfc838679c8574bb2f23828a83cdfc1a761be7c2..a50d1baf43bccdbced5b8b3e1c12120f5c8ffdb8 100644 --- a/db/mysql.go +++ b/db/mysql/mysql.go @@ -21,13 +21,14 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package db +package mysql import ( _ "code.google.com/p/go-mysql-driver/mysql" "database/sql" "fmt" - . "github.com/xiam/gosexy" + "github.com/xiam/gosexy" + "github.com/xiam/gosexy/db" "reflect" "regexp" "strconv" @@ -36,7 +37,7 @@ import ( type myQuery struct { Query []string - sqlArgs []string + SqlArgs []string } func myCompile(terms []interface{}) *myQuery { @@ -50,18 +51,18 @@ func myCompile(terms []interface{}) *myQuery { { q.Query = append(q.Query, term.(string)) } - case sqlArgs: + case db.SqlArgs: { - for _, arg := range term.(sqlArgs) { - q.sqlArgs = append(q.sqlArgs, arg) + for _, arg := range term.(db.SqlArgs) { + q.SqlArgs = append(q.SqlArgs, arg) } } - case sqlValues: + case db.SqlValues: { - args := make([]string, len(term.(sqlValues))) - for i, arg := range term.(sqlValues) { + args := make([]string, len(term.(db.SqlValues))) + for i, arg := range term.(db.SqlValues) { args[i] = "?" - q.sqlArgs = append(q.sqlArgs, arg) + q.SqlArgs = append(q.SqlArgs, arg) } q.Query = append(q.Query, "("+strings.Join(args, ", ")+")") } @@ -79,8 +80,8 @@ func myFields(names []string) string { return "(" + strings.Join(names, ", ") + ")" } -func myValues(values []string) sqlValues { - ret := make(sqlValues, len(values)) +func myValues(values []string) db.SqlValues { + ret := make(db.SqlValues, len(values)) for i, _ := range values { ret[i] = values[i] } @@ -90,14 +91,14 @@ func myValues(values []string) sqlValues { // Stores driver's session data. type MysqlDataSource struct { session *sql.DB - config DataSource - collections map[string]Collection + config db.DataSource + collections map[string]db.Collection } // Returns all items from a query. -func (t *MysqlTable) myFetchAll(rows sql.Rows) []Item { +func (t *MysqlTable) myFetchAll(rows sql.Rows) []db.Item { - items := []Item{} + items := []db.Item{} columns, _ := rows.Columns() @@ -118,7 +119,7 @@ func (t *MysqlTable) myFetchAll(rows sql.Rows) []Item { fn := sn.MethodByName("Scan") for rows.Next() { - item := Item{} + item := db.Item{} ret := fn.Call(fargs) @@ -168,15 +169,15 @@ func (my *MysqlDataSource) myExec(method string, terms ...interface{}) sql.Rows /* fmt.Printf("Q: %v\n", q.Query) - fmt.Printf("A: %v\n", q.sqlArgs) + fmt.Printf("A: %v\n", q.SqlArgs) */ - args := make([]reflect.Value, len(q.sqlArgs)+1) + args := make([]reflect.Value, len(q.SqlArgs)+1) args[0] = reflect.ValueOf(strings.Join(q.Query, " ")) - for i := 0; i < len(q.sqlArgs); i++ { - args[1+i] = reflect.ValueOf(q.sqlArgs[i]) + for i := 0; i < len(q.SqlArgs); i++ { + args[1+i] = reflect.ValueOf(q.SqlArgs[i]) } res := fn.Call(args) @@ -196,10 +197,10 @@ type MysqlTable struct { } // Configures and returns a MySQL database session. -func MysqlSession(config DataSource) Database { +func Session(config db.DataSource) db.Database { my := &MysqlDataSource{} my.config = config - my.collections = make(map[string]Collection) + my.collections = make(map[string]db.Collection) return my } @@ -289,9 +290,9 @@ func (t *MysqlTable) invoke(fn string, terms []interface{}) []reflect.Value { } // A helper for preparing queries that use SET. -func (t *MysqlTable) compileSet(term Set) (string, sqlArgs) { +func (t *MysqlTable) compileSet(term db.Set) (string, db.SqlArgs) { sql := []string{} - args := sqlArgs{} + args := db.SqlArgs{} for key, arg := range term { sql = append(sql, fmt.Sprintf("%s = ?", key)) @@ -302,9 +303,9 @@ func (t *MysqlTable) compileSet(term Set) (string, sqlArgs) { } // A helper for preparing queries that have conditions. -func (t *MysqlTable) compileConditions(term interface{}) (string, sqlArgs) { +func (t *MysqlTable) compileConditions(term interface{}) (string, db.SqlArgs) { sql := []string{} - args := sqlArgs{} + args := db.SqlArgs{} switch term.(type) { case []interface{}: @@ -326,13 +327,13 @@ func (t *MysqlTable) compileConditions(term interface{}) (string, sqlArgs) { return "(" + strings.Join(sql, " AND ") + ")", args } } - case Or: + case db.Or: { - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - rsql, rargs := t.compileConditions(term.(Or)[i]) + rsql, rargs := t.compileConditions(term.(db.Or)[i]) if rsql != "" { sql = append(sql, rsql) for j := 0; j < len(rargs); j++ { @@ -345,13 +346,13 @@ func (t *MysqlTable) compileConditions(term interface{}) (string, sqlArgs) { return "(" + strings.Join(sql, " OR ") + ")", args } } - case And: + case db.And: { - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - rsql, rargs := t.compileConditions(term.(Or)[i]) + rsql, rargs := t.compileConditions(term.(db.Or)[i]) if rsql != "" { sql = append(sql, rsql) for j := 0; j < len(rargs); j++ { @@ -364,9 +365,9 @@ func (t *MysqlTable) compileConditions(term interface{}) (string, sqlArgs) { return "(" + strings.Join(sql, " AND ") + ")", args } } - case Where: + case db.Where: { - return t.marshal(term.(Where)) + return t.marshal(term.(db.Where)) } } @@ -374,8 +375,8 @@ func (t *MysqlTable) compileConditions(term interface{}) (string, sqlArgs) { return "", args } -// Converts Where{} structures into SQL before processing them in a query. -func (t *MysqlTable) marshal(where Where) (string, []string) { +// Converts db.Where{} structures into SQL before processing them in a query. +func (t *MysqlTable) marshal(where db.Where) (string, []string) { for key, val := range where { key = strings.Trim(key, " ") @@ -426,15 +427,15 @@ func (t *MysqlTable) Remove(terms ...interface{}) bool { // Modifies all the rows in the table that match certain conditions. func (t *MysqlTable) Update(terms ...interface{}) bool { var fields string - var fargs sqlArgs + var fargs db.SqlArgs conditions, cargs := t.compileConditions(terms) for _, term := range terms { switch term.(type) { - case Set: + case db.Set: { - fields, fargs = t.compileSet(term.(Set)) + fields, fargs = t.compileSet(term.(db.Set)) } } } @@ -453,7 +454,7 @@ func (t *MysqlTable) Update(terms ...interface{}) bool { } // Returns all the rows in the table that match certain conditions. -func (t *MysqlTable) FindAll(terms ...interface{}) []Item { +func (t *MysqlTable) FindAll(terms ...interface{}) []db.Item { var itop int var relate interface{} @@ -471,25 +472,25 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { term := terms[i] switch term.(type) { - case Limit: + case db.Limit: { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) + limit = fmt.Sprintf("LIMIT %v", term.(db.Limit)) } - case Offset: + case db.Offset: { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) + offset = fmt.Sprintf("OFFSET %v", term.(db.Offset)) } - case Fields: + case db.Fields: { - fields = strings.Join(term.(Fields), ", ") + fields = strings.Join(term.(db.Fields), ", ") } - case Relate: + case db.Relate: { - relate = term.(Relate) + relate = term.(db.Relate) } - case RelateAll: + case db.RelateAll: { - relateAll = term.(RelateAll) + relateAll = term.(db.RelateAll) } } } @@ -509,12 +510,12 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { result := t.myFetchAll(rows) - var relations []Tuple - var rcollection Collection + var relations []gosexy.Tuple + var rcollection db.Collection // This query is related to other collections. if relate != nil { - for rname, rterms := range relate.(Relate) { + for rname, rterms := range relate.(db.Relate) { rcollection = nil @@ -522,9 +523,9 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } @@ -533,21 +534,21 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { rcollection = t.parent.Collection(rname) } - relations = append(relations, Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) } } if relateAll != nil { - for rname, rterms := range relateAll.(RelateAll) { + for rname, rterms := range relateAll.(db.RelateAll) { rcollection = nil ttop := len(rterms) for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } @@ -556,7 +557,7 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { rcollection = t.parent.Collection(rname) } - relations = append(relations, Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) } } @@ -565,11 +566,11 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { jtop := len(relations) itop = len(result) - items := make([]Item, itop) + items := make([]db.Item, itop) for i := 0; i < itop; i++ { - item := Item{} + item := db.Item{} // Default values. for key, val := range result[i] { @@ -583,18 +584,18 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { terms := []interface{}{} - ktop := len(relation["terms"].(On)) + ktop := len(relation["terms"].(db.On)) for k := 0; k < ktop; k++ { //term = tcopy[k] - term = relation["terms"].(On)[k] + term = relation["terms"].(db.On)[k] switch term.(type) { - // Just waiting for Where statements. - case Where: + // Just waiting for db.Where statements. + case db.Where: { - for wkey, wval := range term.(Where) { + for wkey, wval := range term.(db.Where) { //if reflect.TypeOf(wval).Kind() == reflect.String { // does not always work. if reflect.TypeOf(wval).Name() == "string" { // Matching dynamic values. @@ -602,7 +603,7 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { if matched { // Replacing dynamic values. kname := strings.Trim(wval.(string), "{}") - term = Where{wkey: item[kname]} + term = db.Where{wkey: item[kname]} } } } @@ -614,10 +615,10 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { // Executing external query. if relation["all"] == true { value := relation["collection"].(*MysqlTable).invoke("FindAll", terms) - item[relation["name"].(string)] = value[0].Interface().([]Item) + item[relation["name"].(string)] = value[0].Interface().([]db.Item) } else { value := relation["collection"].(*MysqlTable).invoke("Find", terms) - item[relation["name"].(string)] = value[0].Interface().(Item) + item[relation["name"].(string)] = value[0].Interface().(db.Item) } } @@ -632,12 +633,12 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { // Returns the number of rows in the current table that match certain conditions. func (t *MysqlTable) Count(terms ...interface{}) int { - terms = append(terms, Fields{"COUNT(1) AS _total"}) + terms = append(terms, db.Fields{"COUNT(1) AS _total"}) result := t.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { val, _ := strconv.Atoi(response[0]["_total"].(string)) return val @@ -648,16 +649,16 @@ func (t *MysqlTable) Count(terms ...interface{}) int { } // Returns the first row in the table that matches certain conditions. -func (t *MysqlTable) Find(terms ...interface{}) Item { +func (t *MysqlTable) Find(terms ...interface{}) db.Item { - var item Item + var item db.Item - terms = append(terms, Limit(1)) + terms = append(terms, db.Limit(1)) result := t.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { item = response[0] } @@ -678,7 +679,7 @@ func (t *MysqlTable) Append(items ...interface{}) bool { item := items[i] - for field, value := range item.(Item) { + for field, value := range item.(db.Item) { fields = append(fields, field) values = append(values, fmt.Sprintf("%v", value)) } @@ -697,7 +698,7 @@ func (t *MysqlTable) Append(items ...interface{}) bool { } // Returns a MySQL table structure by name. -func (my *MysqlDataSource) Collection(name string) Collection { +func (my *MysqlDataSource) Collection(name string) db.Collection { if collection, ok := my.collections[name]; ok == true { return collection diff --git a/db/mysql_test.go b/db/mysql/mysql_test.go similarity index 52% rename from db/mysql_test.go rename to db/mysql/mysql_test.go index e60adc4af41f07c994cd8fabcc4c6ba43d3aac91..bd2ca96dcd1b976f42b5058841790f36cd665f5d 100644 --- a/db/mysql_test.go +++ b/db/mysql/mysql_test.go @@ -1,9 +1,10 @@ -package db +package mysql import ( "database/sql" "fmt" "github.com/kr/pretty" + "github.com/xiam/gosexy/db" "math/rand" "testing" ) @@ -15,19 +16,19 @@ const myPassword = "gopass" func TestMyTruncate(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - collections := db.Collections() + collections := sess.Collections() for _, name := range collections { - col := db.Collection(name) + col := sess.Collection(name) col.Truncate() if col.Count() != 0 { t.Errorf("Could not truncate '%s'.", name) @@ -38,23 +39,23 @@ func TestMyTruncate(t *testing.T) { func TestMyAppend(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - col := db.Collection("people") + col := sess.Collection("people") col.Truncate() names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} for i := 0; i < len(names); i++ { - col.Append(Item{"name": names[i]}) + col.Append(db.Item{"name": names[i]}) } if col.Count() != len(names) { @@ -65,18 +66,18 @@ func TestMyAppend(t *testing.T) { func TestMyFind(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - col := db.Collection("people") + col := sess.Collection("people") - result := col.Find(Where{"name": "José"}) + result := col.Find(db.Where{"name": "José"}) if result["name"] != "José" { t.Error("Could not find a recently appended item.") @@ -85,20 +86,20 @@ func TestMyFind(t *testing.T) { } func TestMyDelete(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - col := db.Collection("people") + col := sess.Collection("people") - col.Remove(Where{"name": "Juan"}) + col.Remove(db.Where{"name": "Juan"}) - result := col.Find(Where{"name": "Juan"}) + result := col.Find(db.Where{"name": "Juan"}) if len(result) > 0 { t.Error("Could not remove a recently appended item.") @@ -106,22 +107,22 @@ func TestMyDelete(t *testing.T) { } func TestMyUpdate(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - db.Use("test") + sess.Use("test") - col := db.Collection("people") + col := sess.Collection("people") - col.Update(Where{"name": "José"}, Set{"name": "Joseph"}) + col.Update(db.Where{"name": "José"}, db.Set{"name": "Joseph"}) - result := col.Find(Where{"name": "Joseph"}) + result := col.Find(db.Where{"name": "Joseph"}) if len(result) == 0 { t.Error("Could not update a recently appended item.") @@ -131,28 +132,28 @@ func TestMyUpdate(t *testing.T) { func TestMyPopulate(t *testing.T) { var i int - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - db.Use("test") + sess.Use("test") places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} for i = 0; i < len(places); i++ { - db.Collection("places").Append(Item{ + sess.Collection("places").Append(db.Item{ "code_id": i, "name": places[i], }) } - people := db.Collection("people").FindAll( - Fields{"id", "name"}, + people := sess.Collection("people").FindAll( + db.Fields{"id", "name"}, ) for i = 0; i < len(people); i++ { @@ -160,24 +161,24 @@ func TestMyPopulate(t *testing.T) { // Has 5 children. for j := 0; j < 5; j++ { - db.Collection("children").Append(Item{ + sess.Collection("children").Append(db.Item{ "name": fmt.Sprintf("%s's child %d", person["name"], j+1), "parent_id": person["id"], }) } // Lives in - db.Collection("people").Update( - Where{"id": person["id"]}, - Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, + sess.Collection("people").Update( + db.Where{"id": person["id"]}, + db.Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, ) // Has visited for k := 0; k < 3; k++ { - place := db.Collection("places").Find(Where{ + place := sess.Collection("places").Find(db.Where{ "code_id": int(rand.Float32() * float32(len(places))), }) - db.Collection("visits").Append(Item{ + sess.Collection("visits").Append(db.Item{ "place_id": place["id"], "person_id": person["id"], }) @@ -187,36 +188,36 @@ func TestMyPopulate(t *testing.T) { } func TestMyRelation(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - col := db.Collection("people") + col := sess.Collection("people") result := col.FindAll( - Relate{ - "lives_in": On{ - db.Collection("places"), - Where{"code_id": "{place_code_id}"}, + db.Relate{ + "lives_in": db.On{ + sess.Collection("places"), + db.Where{"code_id": "{place_code_id}"}, }, }, - RelateAll{ - "has_children": On{ - db.Collection("children"), - Where{"parent_id": "{id}"}, + db.RelateAll{ + "has_children": db.On{ + sess.Collection("children"), + db.Where{"parent_id": "{id}"}, }, - "has_visited": On{ - db.Collection("visits"), - Where{"person_id": "{id}"}, - Relate{ - "place": On{ - db.Collection("places"), - Where{"id": "{place_id}"}, + "has_visited": db.On{ + sess.Collection("visits"), + db.Where{"person_id": "{id}"}, + db.Relate{ + "place": db.On{ + sess.Collection("places"), + db.Where{"id": "{place_id}"}, }, }, }, @@ -227,16 +228,16 @@ func TestMyRelation(t *testing.T) { } func TestCustom(t *testing.T) { - db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + sess := Session(db.DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Open() - defer db.Close() + err := sess.Open() + defer sess.Close() if err != nil { panic(err) } - _, err = db.Driver().(*sql.DB).Query("SELECT NOW()") + _, err = sess.Driver().(*sql.DB).Query("SELECT NOW()") if err != nil { panic(err) diff --git a/db/postgresql.go b/db/postgresql/postgresql.go similarity index 80% rename from db/postgresql.go rename to db/postgresql/postgresql.go index 31697c5e58446118afd1ca638cac46b2bfbf2628..4a8d1afd8710fe6991e7a70b82353d8930e30c4b 100644 --- a/db/postgresql.go +++ b/db/postgresql/postgresql.go @@ -21,13 +21,14 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package db +package postgresql import ( "database/sql" "fmt" _ "github.com/xiam/gopostgresql" - . "github.com/xiam/gosexy" + "github.com/xiam/gosexy" + "github.com/xiam/gosexy/db" "reflect" "regexp" "strconv" @@ -36,7 +37,7 @@ import ( type pgQuery struct { Query []string - sqlArgs []string + SqlArgs []string } func pgCompile(terms []interface{}) *pgQuery { @@ -50,18 +51,18 @@ func pgCompile(terms []interface{}) *pgQuery { { q.Query = append(q.Query, term.(string)) } - case sqlArgs: + case db.SqlArgs: { - for _, arg := range term.(sqlArgs) { - q.sqlArgs = append(q.sqlArgs, arg) + for _, arg := range term.(db.SqlArgs) { + q.SqlArgs = append(q.SqlArgs, arg) } } - case sqlValues: + case db.SqlValues: { - args := make([]string, len(term.(sqlValues))) - for i, arg := range term.(sqlValues) { + args := make([]string, len(term.(db.SqlValues))) + for i, arg := range term.(db.SqlValues) { args[i] = "?" - q.sqlArgs = append(q.sqlArgs, arg) + q.SqlArgs = append(q.SqlArgs, arg) } q.Query = append(q.Query, "("+strings.Join(args, ", ")+")") } @@ -79,8 +80,8 @@ func pgFields(names []string) string { return "(" + strings.Join(names, ", ") + ")" } -func pgValues(values []string) sqlValues { - ret := make(sqlValues, len(values)) +func pgValues(values []string) db.SqlValues { + ret := make(db.SqlValues, len(values)) for i, _ := range values { ret[i] = values[i] } @@ -89,14 +90,14 @@ func pgValues(values []string) sqlValues { // Stores PostgreSQL session data. type PostgresqlDataSource struct { - config DataSource + config db.DataSource session *sql.DB - collections map[string]Collection + collections map[string]db.Collection } -func (t *PostgresqlTable) pgFetchAll(rows sql.Rows) []Item { +func (t *PostgresqlTable) pgFetchAll(rows sql.Rows) []db.Item { - items := []Item{} + items := []db.Item{} columns, _ := rows.Columns() @@ -117,7 +118,7 @@ func (t *PostgresqlTable) pgFetchAll(rows sql.Rows) []Item { fn := sn.MethodByName("Scan") for rows.Next() { - item := Item{} + item := db.Item{} ret := fn.Call(fargs) @@ -165,15 +166,15 @@ func (pg *PostgresqlDataSource) pgExec(method string, terms ...interface{}) sql. q := pgCompile(terms) //fmt.Printf("Q: %v\n", q.Query) - //fmt.Printf("A: %v\n", q.sqlArgs) + //fmt.Printf("A: %v\n", q.SqlArgs) qs := strings.Join(q.Query, " ") - args := make([]reflect.Value, len(q.sqlArgs)+1) + args := make([]reflect.Value, len(q.SqlArgs)+1) - for i := 0; i < len(q.sqlArgs); i++ { + for i := 0; i < len(q.SqlArgs); i++ { qs = strings.Replace(qs, "?", fmt.Sprintf("$%d", i+1), 1) - args[1+i] = reflect.ValueOf(q.sqlArgs[i]) + args[1+i] = reflect.ValueOf(q.SqlArgs[i]) } args[0] = reflect.ValueOf(qs) @@ -195,10 +196,10 @@ type PostgresqlTable struct { } // Configures and returns a PostgreSQL dabase session. -func PostgresqlSession(config DataSource) Database { +func Session(config db.DataSource) db.Database { m := &PostgresqlDataSource{} m.config = config - m.collections = make(map[string]Collection) + m.collections = make(map[string]db.Collection) return m } @@ -285,9 +286,9 @@ func (t *PostgresqlTable) invoke(fn string, terms []interface{}) []reflect.Value return exec } -func (t *PostgresqlTable) compileSet(term Set) (string, sqlArgs) { +func (t *PostgresqlTable) compileSet(term db.Set) (string, db.SqlArgs) { sql := []string{} - args := sqlArgs{} + args := db.SqlArgs{} for key, arg := range term { sql = append(sql, fmt.Sprintf("%s = ?", key)) @@ -297,9 +298,9 @@ func (t *PostgresqlTable) compileSet(term Set) (string, sqlArgs) { return strings.Join(sql, ", "), args } -func (t *PostgresqlTable) compileConditions(term interface{}) (string, sqlArgs) { +func (t *PostgresqlTable) compileConditions(term interface{}) (string, db.SqlArgs) { sql := []string{} - args := sqlArgs{} + args := db.SqlArgs{} switch term.(type) { case []interface{}: @@ -321,13 +322,13 @@ func (t *PostgresqlTable) compileConditions(term interface{}) (string, sqlArgs) return "(" + strings.Join(sql, " AND ") + ")", args } } - case Or: + case db.Or: { - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - rsql, rargs := t.compileConditions(term.(Or)[i]) + rsql, rargs := t.compileConditions(term.(db.Or)[i]) if rsql != "" { sql = append(sql, rsql) for j := 0; j < len(rargs); j++ { @@ -340,13 +341,13 @@ func (t *PostgresqlTable) compileConditions(term interface{}) (string, sqlArgs) return "(" + strings.Join(sql, " OR ") + ")", args } } - case And: + case db.And: { - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - rsql, rargs := t.compileConditions(term.(Or)[i]) + rsql, rargs := t.compileConditions(term.(db.Or)[i]) if rsql != "" { sql = append(sql, rsql) for j := 0; j < len(rargs); j++ { @@ -359,9 +360,9 @@ func (t *PostgresqlTable) compileConditions(term interface{}) (string, sqlArgs) return "(" + strings.Join(sql, " AND ") + ")", args } } - case Where: + case db.Where: { - return t.marshal(term.(Where)) + return t.marshal(term.(db.Where)) } } @@ -369,7 +370,7 @@ func (t *PostgresqlTable) compileConditions(term interface{}) (string, sqlArgs) return "", args } -func (t *PostgresqlTable) marshal(where Where) (string, []string) { +func (t *PostgresqlTable) marshal(where db.Where) (string, []string) { for key, val := range where { key = strings.Trim(key, " ") @@ -420,15 +421,15 @@ func (t *PostgresqlTable) Remove(terms ...interface{}) bool { // Modifies all the rows in the table that match certain conditions. func (t *PostgresqlTable) Update(terms ...interface{}) bool { var fields string - var fargs sqlArgs + var fargs db.SqlArgs conditions, cargs := t.compileConditions(terms) for _, term := range terms { switch term.(type) { - case Set: + case db.Set: { - fields, fargs = t.compileSet(term.(Set)) + fields, fargs = t.compileSet(term.(db.Set)) } } } @@ -447,7 +448,7 @@ func (t *PostgresqlTable) Update(terms ...interface{}) bool { } // Returns all the rows in the table that match certain conditions. -func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { +func (t *PostgresqlTable) FindAll(terms ...interface{}) []db.Item { var itop int var relate interface{} @@ -465,25 +466,25 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { term := terms[i] switch term.(type) { - case Limit: + case db.Limit: { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) + limit = fmt.Sprintf("LIMIT %v", term.(db.Limit)) } - case Offset: + case db.Offset: { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) + offset = fmt.Sprintf("OFFSET %v", term.(db.Offset)) } - case Fields: + case db.Fields: { - fields = strings.Join(term.(Fields), ", ") + fields = strings.Join(term.(db.Fields), ", ") } - case Relate: + case db.Relate: { - relate = term.(Relate) + relate = term.(db.Relate) } - case RelateAll: + case db.RelateAll: { - relateAll = term.(RelateAll) + relateAll = term.(db.RelateAll) } } } @@ -503,12 +504,12 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { result := t.pgFetchAll(rows) - var relations []Tuple - var rcollection Collection + var relations []gosexy.Tuple + var rcollection db.Collection // This query is related to other collections. if relate != nil { - for rname, rterms := range relate.(Relate) { + for rname, rterms := range relate.(db.Relate) { rcollection = nil @@ -516,9 +517,9 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } @@ -527,21 +528,21 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { rcollection = t.parent.Collection(rname) } - relations = append(relations, Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) } } if relateAll != nil { - for rname, rterms := range relateAll.(RelateAll) { + for rname, rterms := range relateAll.(db.RelateAll) { rcollection = nil ttop := len(rterms) for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } @@ -550,7 +551,7 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { rcollection = t.parent.Collection(rname) } - relations = append(relations, Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) } } @@ -559,11 +560,11 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { jtop := len(relations) itop = len(result) - items := make([]Item, itop) + items := make([]db.Item, itop) for i := 0; i < itop; i++ { - item := Item{} + item := db.Item{} // Default values. for key, val := range result[i] { @@ -577,18 +578,18 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { terms := []interface{}{} - ktop := len(relation["terms"].(On)) + ktop := len(relation["terms"].(db.On)) for k := 0; k < ktop; k++ { //term = tcopy[k] - term = relation["terms"].(On)[k] + term = relation["terms"].(db.On)[k] switch term.(type) { - // Just waiting for Where statements. - case Where: + // Just waiting for db.Where statements. + case db.Where: { - for wkey, wval := range term.(Where) { + for wkey, wval := range term.(db.Where) { //if reflect.TypeOf(wval).Kind() == reflect.String { // does not always work. if reflect.TypeOf(wval).Name() == "string" { // Matching dynamic values. @@ -596,7 +597,7 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { if matched { // Replacing dynamic values. kname := strings.Trim(wval.(string), "{}") - term = Where{wkey: item[kname]} + term = db.Where{wkey: item[kname]} } } } @@ -608,10 +609,10 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { // Executing external query. if relation["all"] == true { value := relation["collection"].(*PostgresqlTable).invoke("FindAll", terms) - item[relation["name"].(string)] = value[0].Interface().([]Item) + item[relation["name"].(string)] = value[0].Interface().([]db.Item) } else { value := relation["collection"].(*PostgresqlTable).invoke("Find", terms) - item[relation["name"].(string)] = value[0].Interface().(Item) + item[relation["name"].(string)] = value[0].Interface().(db.Item) } } @@ -626,12 +627,12 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []Item { // Returns the number of rows in the current table that match certain conditions. func (t *PostgresqlTable) Count(terms ...interface{}) int { - terms = append(terms, Fields{"COUNT(1) AS _total"}) + terms = append(terms, db.Fields{"COUNT(1) AS _total"}) result := t.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { val, _ := strconv.Atoi(response[0]["_total"].(string)) return val @@ -642,16 +643,16 @@ func (t *PostgresqlTable) Count(terms ...interface{}) int { } // Returns the first row in the table that matches certain conditions. -func (t *PostgresqlTable) Find(terms ...interface{}) Item { +func (t *PostgresqlTable) Find(terms ...interface{}) db.Item { - var item Item + var item db.Item - terms = append(terms, Limit(1)) + terms = append(terms, db.Limit(1)) result := t.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { item = response[0] } @@ -672,7 +673,7 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool { item := items[i] - for field, value := range item.(Item) { + for field, value := range item.(db.Item) { fields = append(fields, field) values = append(values, fmt.Sprintf("%v", value)) } @@ -691,7 +692,7 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool { } // Returns a MySQL table structure by name. -func (pg *PostgresqlDataSource) Collection(name string) Collection { +func (pg *PostgresqlDataSource) Collection(name string) db.Collection { if collection, ok := pg.collections[name]; ok == true { return collection @@ -706,7 +707,7 @@ func (pg *PostgresqlDataSource) Collection(name string) Collection { rows := t.parent.pgExec( "Query", - "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", sqlArgs{t.name}, + "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", db.SqlArgs{t.name}, ) columns := t.pgFetchAll(rows) diff --git a/db/postgresql/postgresql_test.go b/db/postgresql/postgresql_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0e4504297c6ca4a51c387bb3c2544e1323dedbd3 --- /dev/null +++ b/db/postgresql/postgresql_test.go @@ -0,0 +1,227 @@ +package postgresql + +import ( + "fmt" + "github.com/kr/pretty" + "github.com/xiam/gosexy/db" + "math/rand" + "testing" +) + +const pgHost = "10.0.0.11" +const pgDatabase = "gotest" +const pgUser = "gouser" +const pgPassword = "gopass" + +func TestPgTruncate(t *testing.T) { + + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + collections := sess.Collections() + + for _, name := range collections { + col := sess.Collection(name) + col.Truncate() + if col.Count() != 0 { + t.Errorf("Could not truncate '%s'.", name) + } + } + +} + +func TestPgAppend(t *testing.T) { + + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Truncate() + + names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} + + for i := 0; i < len(names); i++ { + col.Append(db.Item{"name": names[i]}) + } + + if col.Count() != len(names) { + t.Error("Could not append all items.") + } + +} + +func TestPgFind(t *testing.T) { + + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.Find(db.Where{"name": "José"}) + + if result["name"] != "José" { + t.Error("Could not find a recently appended item.") + } + +} + +func TestPgDelete(t *testing.T) { + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Remove(db.Where{"name": "Juan"}) + + result := col.Find(db.Where{"name": "Juan"}) + + if len(result) > 0 { + t.Error("Could not remove a recently appended item.") + } +} + +func TestPgUpdate(t *testing.T) { + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + sess.Use("gotest") + + col := sess.Collection("people") + + col.Update(db.Where{"name": "José"}, db.Set{"name": "Joseph"}) + + result := col.Find(db.Where{"name": "Joseph"}) + + if len(result) == 0 { + t.Error("Could not update a recently appended item.") + } +} + +func TestPgPopulate(t *testing.T) { + var i int + + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + sess.Use("gotest") + + places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} + + for i = 0; i < len(places); i++ { + sess.Collection("places").Append(db.Item{ + "code_id": i, + "name": places[i], + }) + } + + people := sess.Collection("people").FindAll( + db.Fields{"id", "name"}, + ) + + for i = 0; i < len(people); i++ { + person := people[i] + + // Has 5 children. + for j := 0; j < 5; j++ { + sess.Collection("children").Append(db.Item{ + "name": fmt.Sprintf("%s's child %d", person["name"], j+1), + "parent_id": person["id"], + }) + } + + // Lives in + sess.Collection("people").Update( + db.Where{"id": person["id"]}, + db.Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, + ) + + // Has visited + for k := 0; k < 3; k++ { + place := sess.Collection("places").Find(db.Where{ + "code_id": int(rand.Float32() * float32(len(places))), + }) + sess.Collection("visits").Append(db.Item{ + "place_id": place["id"], + "person_id": person["id"], + }) + } + } + +} + +func TestPgRelation(t *testing.T) { + sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.FindAll( + db.Relate{ + "lives_in": db.On{ + sess.Collection("places"), + db.Where{"code_id": "{place_code_id}"}, + }, + }, + db.RelateAll{ + "has_children": db.On{ + sess.Collection("children"), + db.Where{"parent_id": "{id}"}, + }, + "has_visited": db.On{ + sess.Collection("visits"), + db.Where{"person_id": "{id}"}, + db.Relate{ + "place": db.On{ + sess.Collection("places"), + db.Where{"id": "{place_id}"}, + }, + }, + }, + }, + ) + + fmt.Printf("%# v\n", pretty.Formatter(result)) +} diff --git a/db/postgresql_test.go b/db/postgresql_test.go deleted file mode 100644 index 7c5f36fdd2b007b10b8947e07ce2936f64b74f73..0000000000000000000000000000000000000000 --- a/db/postgresql_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package db - -import ( - "fmt" - "github.com/kr/pretty" - "math/rand" - "testing" -) - -const pgHost = "10.0.0.11" -const pgDatabase = "gotest" -const pgUser = "gouser" -const pgPassword = "gopass" - -func TestPgTruncate(t *testing.T) { - - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - collections := db.Collections() - - for _, name := range collections { - col := db.Collection(name) - col.Truncate() - if col.Count() != 0 { - t.Errorf("Could not truncate '%s'.", name) - } - } - -} - -func TestPgAppend(t *testing.T) { - - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Truncate() - - names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} - - for i := 0; i < len(names); i++ { - col.Append(Item{"name": names[i]}) - } - - if col.Count() != len(names) { - t.Error("Could not append all items.") - } - -} - -func TestPgFind(t *testing.T) { - - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.Find(Where{"name": "José"}) - - if result["name"] != "José" { - t.Error("Could not find a recently appended item.") - } - -} - -func TestPgDelete(t *testing.T) { - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Remove(Where{"name": "Juan"}) - - result := col.Find(Where{"name": "Juan"}) - - if len(result) > 0 { - t.Error("Could not remove a recently appended item.") - } -} - -func TestPgUpdate(t *testing.T) { - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - db.Use("gotest") - - col := db.Collection("people") - - col.Update(Where{"name": "José"}, Set{"name": "Joseph"}) - - result := col.Find(Where{"name": "Joseph"}) - - if len(result) == 0 { - t.Error("Could not update a recently appended item.") - } -} - -func TestPgPopulate(t *testing.T) { - var i int - - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - db.Use("gotest") - - places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} - - for i = 0; i < len(places); i++ { - db.Collection("places").Append(Item{ - "code_id": i, - "name": places[i], - }) - } - - people := db.Collection("people").FindAll( - Fields{"id", "name"}, - ) - - for i = 0; i < len(people); i++ { - person := people[i] - - // Has 5 children. - for j := 0; j < 5; j++ { - db.Collection("children").Append(Item{ - "name": fmt.Sprintf("%s's child %d", person["name"], j+1), - "parent_id": person["id"], - }) - } - - // Lives in - db.Collection("people").Update( - Where{"id": person["id"]}, - Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, - ) - - // Has visited - for k := 0; k < 3; k++ { - place := db.Collection("places").Find(Where{ - "code_id": int(rand.Float32() * float32(len(places))), - }) - db.Collection("visits").Append(Item{ - "place_id": place["id"], - "person_id": person["id"], - }) - } - } - -} - -func TestPgRelation(t *testing.T) { - db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.FindAll( - Relate{ - "lives_in": On{ - db.Collection("places"), - Where{"code_id": "{place_code_id}"}, - }, - }, - RelateAll{ - "has_children": On{ - db.Collection("children"), - Where{"parent_id": "{id}"}, - }, - "has_visited": On{ - db.Collection("visits"), - Where{"person_id": "{id}"}, - Relate{ - "place": On{ - db.Collection("places"), - Where{"id": "{place_id}"}, - }, - }, - }, - }, - ) - - fmt.Printf("%# v\n", pretty.Formatter(result)) -} diff --git a/db/sqlite/dumps/gotest.sqlite3.db b/db/sqlite/dumps/gotest.sqlite3.db new file mode 100644 index 0000000000000000000000000000000000000000..c0aca4901735ef6aa1dfaef310e3a6f928eb940e Binary files /dev/null and b/db/sqlite/dumps/gotest.sqlite3.db differ diff --git a/db/sqlite.go b/db/sqlite/sqlite.go similarity index 80% rename from db/sqlite.go rename to db/sqlite/sqlite.go index 1a0d4c8c3242abf2fd6e3bd168af89653f7bfa2d..dd912926dcc546e64848c8b2bf3b8de6fc609af2 100644 --- a/db/sqlite.go +++ b/db/sqlite/sqlite.go @@ -21,12 +21,13 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -package db +package sqlite import ( "database/sql" "fmt" - . "github.com/xiam/gosexy" + "github.com/xiam/gosexy" + "github.com/xiam/gosexy/db" _ "github.com/xiam/gosqlite3" "reflect" "regexp" @@ -36,7 +37,7 @@ import ( type slQuery struct { Query []string - sqlArgs []string + SqlArgs []string } func slCompile(terms []interface{}) *slQuery { @@ -50,18 +51,18 @@ func slCompile(terms []interface{}) *slQuery { { q.Query = append(q.Query, term.(string)) } - case sqlArgs: + case db.SqlArgs: { - for _, arg := range term.(sqlArgs) { - q.sqlArgs = append(q.sqlArgs, arg) + for _, arg := range term.(db.SqlArgs) { + q.SqlArgs = append(q.SqlArgs, arg) } } - case sqlValues: + case db.SqlValues: { - args := make([]string, len(term.(sqlValues))) - for i, arg := range term.(sqlValues) { + args := make([]string, len(term.(db.SqlValues))) + for i, arg := range term.(db.SqlValues) { args[i] = "?" - q.sqlArgs = append(q.sqlArgs, arg) + q.SqlArgs = append(q.SqlArgs, arg) } q.Query = append(q.Query, "("+strings.Join(args, ", ")+")") } @@ -79,8 +80,8 @@ func slFields(names []string) string { return "(" + strings.Join(names, ", ") + ")" } -func slValues(values []string) sqlValues { - ret := make(sqlValues, len(values)) +func slValues(values []string) db.SqlValues { + ret := make(db.SqlValues, len(values)) for i, _ := range values { ret[i] = values[i] } @@ -89,14 +90,14 @@ func slValues(values []string) sqlValues { // Stores driver's session data. type SqliteDataSource struct { - config DataSource + config db.DataSource session *sql.DB - collections map[string]Collection + collections map[string]db.Collection } -func (t *SqliteTable) slFetchAll(rows sql.Rows) []Item { +func (t *SqliteTable) slFetchAll(rows sql.Rows) []db.Item { - items := []Item{} + items := []db.Item{} columns, _ := rows.Columns() @@ -117,7 +118,7 @@ func (t *SqliteTable) slFetchAll(rows sql.Rows) []Item { fn := sn.MethodByName("Scan") for rows.Next() { - item := Item{} + item := db.Item{} ret := fn.Call(fargs) @@ -168,15 +169,15 @@ func (sl *SqliteDataSource) slExec(method string, terms ...interface{}) sql.Rows /* fmt.Printf("Q: %v\n", q.Query) - fmt.Printf("A: %v\n", q.sqlArgs) + fmt.Printf("A: %v\n", q.SqlArgs) */ - args := make([]reflect.Value, len(q.sqlArgs)+1) + args := make([]reflect.Value, len(q.SqlArgs)+1) args[0] = reflect.ValueOf(strings.Join(q.Query, " ")) - for i := 0; i < len(q.sqlArgs); i++ { - args[1+i] = reflect.ValueOf(q.sqlArgs[i]) + for i := 0; i < len(q.SqlArgs); i++ { + args[1+i] = reflect.ValueOf(q.SqlArgs[i]) } res := fn.Call(args) @@ -206,10 +207,10 @@ type SqliteTable struct { } // Configures and returns a SQLite database session. -func SqliteSession(config DataSource) Database { +func SqliteSession(config db.DataSource) db.Database { m := &SqliteDataSource{} m.config = config - m.collections = make(map[string]Collection) + m.collections = make(map[string]db.Collection) return m } @@ -290,9 +291,9 @@ func (t *SqliteTable) invoke(fn string, terms []interface{}) []reflect.Value { return exec } -func (t *SqliteTable) compileSet(term Set) (string, sqlArgs) { +func (t *SqliteTable) compileSet(term db.Set) (string, db.SqlArgs) { sql := []string{} - args := sqlArgs{} + args := db.SqlArgs{} for key, arg := range term { sql = append(sql, fmt.Sprintf("%s = ?", key)) @@ -302,9 +303,9 @@ func (t *SqliteTable) compileSet(term Set) (string, sqlArgs) { return strings.Join(sql, ", "), args } -func (t *SqliteTable) compileConditions(term interface{}) (string, sqlArgs) { +func (t *SqliteTable) compileConditions(term interface{}) (string, db.SqlArgs) { sql := []string{} - args := sqlArgs{} + args := db.SqlArgs{} switch term.(type) { case []interface{}: @@ -326,13 +327,13 @@ func (t *SqliteTable) compileConditions(term interface{}) (string, sqlArgs) { return "(" + strings.Join(sql, " AND ") + ")", args } } - case Or: + case db.Or: { - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - rsql, rargs := t.compileConditions(term.(Or)[i]) + rsql, rargs := t.compileConditions(term.(db.Or)[i]) if rsql != "" { sql = append(sql, rsql) for j := 0; j < len(rargs); j++ { @@ -345,13 +346,13 @@ func (t *SqliteTable) compileConditions(term interface{}) (string, sqlArgs) { return "(" + strings.Join(sql, " OR ") + ")", args } } - case And: + case db.And: { - itop := len(term.(Or)) + itop := len(term.(db.Or)) for i := 0; i < itop; i++ { - rsql, rargs := t.compileConditions(term.(Or)[i]) + rsql, rargs := t.compileConditions(term.(db.Or)[i]) if rsql != "" { sql = append(sql, rsql) for j := 0; j < len(rargs); j++ { @@ -364,9 +365,9 @@ func (t *SqliteTable) compileConditions(term interface{}) (string, sqlArgs) { return "(" + strings.Join(sql, " AND ") + ")", args } } - case Where: + case db.Where: { - return t.marshal(term.(Where)) + return t.marshal(term.(db.Where)) } } @@ -374,7 +375,7 @@ func (t *SqliteTable) compileConditions(term interface{}) (string, sqlArgs) { return "", args } -func (t *SqliteTable) marshal(where Where) (string, []string) { +func (t *SqliteTable) marshal(where db.Where) (string, []string) { for key, val := range where { key = strings.Trim(key, " ") @@ -425,15 +426,15 @@ func (t *SqliteTable) Remove(terms ...interface{}) bool { // Modifies all the rows in the table that match certain conditions. func (t *SqliteTable) Update(terms ...interface{}) bool { var fields string - var fargs sqlArgs + var fargs db.SqlArgs conditions, cargs := t.compileConditions(terms) for _, term := range terms { switch term.(type) { - case Set: + case db.Set: { - fields, fargs = t.compileSet(term.(Set)) + fields, fargs = t.compileSet(term.(db.Set)) } } } @@ -452,7 +453,7 @@ func (t *SqliteTable) Update(terms ...interface{}) bool { } // Returns all the rows in the table that match certain conditions. -func (t *SqliteTable) FindAll(terms ...interface{}) []Item { +func (t *SqliteTable) FindAll(terms ...interface{}) []db.Item { var itop int var relate interface{} @@ -470,25 +471,25 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { term := terms[i] switch term.(type) { - case Limit: + case db.Limit: { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) + limit = fmt.Sprintf("LIMIT %v", term.(db.Limit)) } - case Offset: + case db.Offset: { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) + offset = fmt.Sprintf("OFFSET %v", term.(db.Offset)) } - case Fields: + case db.Fields: { - fields = strings.Join(term.(Fields), ", ") + fields = strings.Join(term.(db.Fields), ", ") } - case Relate: + case db.Relate: { - relate = term.(Relate) + relate = term.(db.Relate) } - case RelateAll: + case db.RelateAll: { - relateAll = term.(RelateAll) + relateAll = term.(db.RelateAll) } } } @@ -508,12 +509,12 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { result := t.slFetchAll(rows) - var relations []Tuple - var rcollection Collection + var relations []gosexy.Tuple + var rcollection db.Collection // This query is related to other collections. if relate != nil { - for rname, rterms := range relate.(Relate) { + for rname, rterms := range relate.(db.Relate) { rcollection = nil @@ -521,9 +522,9 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } @@ -532,21 +533,21 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { rcollection = t.parent.Collection(rname) } - relations = append(relations, Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": false, "name": rname, "collection": rcollection, "terms": rterms}) } } if relateAll != nil { - for rname, rterms := range relateAll.(RelateAll) { + for rname, rterms := range relateAll.(db.RelateAll) { rcollection = nil ttop := len(rterms) for t := ttop - 1; t >= 0; t-- { rterm := rterms[t] switch rterm.(type) { - case Collection: + case db.Collection: { - rcollection = rterm.(Collection) + rcollection = rterm.(db.Collection) } } } @@ -555,7 +556,7 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { rcollection = t.parent.Collection(rname) } - relations = append(relations, Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) + relations = append(relations, gosexy.Tuple{"all": true, "name": rname, "collection": rcollection, "terms": rterms}) } } @@ -564,11 +565,11 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { jtop := len(relations) itop = len(result) - items := make([]Item, itop) + items := make([]db.Item, itop) for i := 0; i < itop; i++ { - item := Item{} + item := db.Item{} // Default values. for key, val := range result[i] { @@ -582,18 +583,18 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { terms := []interface{}{} - ktop := len(relation["terms"].(On)) + ktop := len(relation["terms"].(db.On)) for k := 0; k < ktop; k++ { //term = tcopy[k] - term = relation["terms"].(On)[k] + term = relation["terms"].(db.On)[k] switch term.(type) { - // Just waiting for Where statements. - case Where: + // Just waiting for db.Where statements. + case db.Where: { - for wkey, wval := range term.(Where) { + for wkey, wval := range term.(db.Where) { //if reflect.TypeOf(wval).Kind() == reflect.String { // does not always work. if reflect.TypeOf(wval).Name() == "string" { // Matching dynamic values. @@ -601,7 +602,7 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { if matched { // Replacing dynamic values. kname := strings.Trim(wval.(string), "{}") - term = Where{wkey: item[kname]} + term = db.Where{wkey: item[kname]} } } } @@ -613,10 +614,10 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { // Executing external query. if relation["all"] == true { value := relation["collection"].(*SqliteTable).invoke("FindAll", terms) - item[relation["name"].(string)] = value[0].Interface().([]Item) + item[relation["name"].(string)] = value[0].Interface().([]db.Item) } else { value := relation["collection"].(*SqliteTable).invoke("Find", terms) - item[relation["name"].(string)] = value[0].Interface().(Item) + item[relation["name"].(string)] = value[0].Interface().(db.Item) } } @@ -631,12 +632,12 @@ func (t *SqliteTable) FindAll(terms ...interface{}) []Item { // Returns the number of rows in the current table that match certain conditions. func (t *SqliteTable) Count(terms ...interface{}) int { - terms = append(terms, Fields{"COUNT(1) AS _total"}) + terms = append(terms, db.Fields{"COUNT(1) AS _total"}) result := t.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { val, _ := strconv.Atoi(response[0]["_total"].(string)) return val @@ -647,16 +648,16 @@ func (t *SqliteTable) Count(terms ...interface{}) int { } // Returns the first row in the table that matches certain conditions. -func (t *SqliteTable) Find(terms ...interface{}) Item { +func (t *SqliteTable) Find(terms ...interface{}) db.Item { - var item Item + var item db.Item - terms = append(terms, Limit(1)) + terms = append(terms, db.Limit(1)) result := t.invoke("FindAll", terms) if len(result) > 0 { - response := result[0].Interface().([]Item) + response := result[0].Interface().([]db.Item) if len(response) > 0 { item = response[0] } @@ -677,7 +678,7 @@ func (t *SqliteTable) Append(items ...interface{}) bool { item := items[i] - for field, value := range item.(Item) { + for field, value := range item.(db.Item) { fields = append(fields, field) values = append(values, fmt.Sprintf("%v", value)) } @@ -697,7 +698,7 @@ func (t *SqliteTable) Append(items ...interface{}) bool { } // Returns a SQLite table structure by name. -func (sl *SqliteDataSource) Collection(name string) Collection { +func (sl *SqliteDataSource) Collection(name string) db.Collection { if collection, ok := sl.collections[name]; ok == true { return collection diff --git a/db/sqlite/sqlite_test.go b/db/sqlite/sqlite_test.go new file mode 100644 index 0000000000000000000000000000000000000000..528a83aca11c08315e4f738766baabe2a4dd7ce5 --- /dev/null +++ b/db/sqlite/sqlite_test.go @@ -0,0 +1,222 @@ +package sqlite + +import ( + "fmt" + "github.com/kr/pretty" + "github.com/xiam/gosexy/db" + "math/rand" + "testing" +) + +const sqDatabase = "./dumps/gotest.sqlite3.db" + +func TestSqTruncate(t *testing.T) { + + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + collections := sess.Collections() + + for _, name := range collections { + col := sess.Collection(name) + col.Truncate() + if col.Count() != 0 { + t.Errorf("Could not truncate '%s'.", name) + } + } + +} + +func TestSqAppend(t *testing.T) { + + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Truncate() + + names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} + + for i := 0; i < len(names); i++ { + col.Append(db.Item{"name": names[i]}) + } + + if col.Count() != len(names) { + panic(fmt.Errorf("Could not append all items")) + } + +} + +func TestSqFind(t *testing.T) { + + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.Find(db.Where{"name": "José"}) + + if result["name"] != "José" { + t.Error("Could not find a recently appended item.") + } + +} + +func TestSqDelete(t *testing.T) { + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Remove(db.Where{"name": "Juan"}) + + result := col.Find(db.Where{"name": "Juan"}) + + if len(result) > 0 { + t.Error("Could not remove a recently appended item.") + } +} + +func TestSqUpdate(t *testing.T) { + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + col.Update(db.Where{"name": "José"}, db.Set{"name": "Joseph"}) + + result := col.Find(db.Where{"name": "Joseph"}) + + if len(result) == 0 { + t.Error("Could not update a recently appended item.") + } +} + +func TestSqPopulate(t *testing.T) { + var i int + + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + sess.Use("test") + + places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} + + for i = 0; i < len(places); i++ { + sess.Collection("places").Append(db.Item{ + "code_id": i, + "name": places[i], + }) + } + + people := sess.Collection("people").FindAll( + db.Fields{"id", "name"}, + ) + + for i = 0; i < len(people); i++ { + person := people[i] + + // Has 5 children. + for j := 0; j < 5; j++ { + sess.Collection("children").Append(db.Item{ + "name": fmt.Sprintf("%s's child %d", person["name"], j+1), + "parent_id": person["id"], + }) + } + + // Lives in + sess.Collection("people").Update( + db.Where{"id": person["id"]}, + db.Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, + ) + + // Has visited + for k := 0; k < 3; k++ { + place := sess.Collection("places").Find(db.Where{ + "code_id": int(rand.Float32() * float32(len(places))), + }) + sess.Collection("visits").Append(db.Item{ + "place_id": place["id"], + "person_id": person["id"], + }) + } + } + +} + +func TestSqRelation(t *testing.T) { + sess := SqliteSession(db.DataSource{Database: sqDatabase}) + + err := sess.Open() + defer sess.Close() + + if err != nil { + panic(err) + } + + col := sess.Collection("people") + + result := col.FindAll( + db.Relate{ + "lives_in": db.On{ + sess.Collection("places"), + db.Where{"code_id": "{place_code_id}"}, + }, + }, + db.RelateAll{ + "has_children": db.On{ + sess.Collection("children"), + db.Where{"parent_id": "{id}"}, + }, + "has_visited": db.On{ + sess.Collection("visits"), + db.Where{"person_id": "{id}"}, + db.Relate{ + "place": db.On{ + sess.Collection("places"), + db.Where{"id": "{place_id}"}, + }, + }, + }, + }, + ) + + fmt.Printf("%# v\n", pretty.Formatter(result)) +} diff --git a/db/sqlite_test.go b/db/sqlite_test.go deleted file mode 100644 index 480d94f773e6693010651fd67de45097b05d52b1..0000000000000000000000000000000000000000 --- a/db/sqlite_test.go +++ /dev/null @@ -1,221 +0,0 @@ -package db - -import ( - "fmt" - "github.com/kr/pretty" - "math/rand" - "testing" -) - -const sqDatabase = "./dumps/gotest.sqlite3.db" - -func TestSqTruncate(t *testing.T) { - - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - collections := db.Collections() - - for _, name := range collections { - col := db.Collection(name) - col.Truncate() - if col.Count() != 0 { - t.Errorf("Could not truncate '%s'.", name) - } - } - -} - -func TestSqAppend(t *testing.T) { - - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Truncate() - - names := []string{"Juan", "José", "Pedro", "MarÃa", "Roberto", "Manuel", "Miguel"} - - for i := 0; i < len(names); i++ { - col.Append(Item{"name": names[i]}) - } - - if col.Count() != len(names) { - panic(fmt.Errorf("Could not append all items")) - } - -} - -func TestSqFind(t *testing.T) { - - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.Find(Where{"name": "José"}) - - if result["name"] != "José" { - t.Error("Could not find a recently appended item.") - } - -} - -func TestSqDelete(t *testing.T) { - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Remove(Where{"name": "Juan"}) - - result := col.Find(Where{"name": "Juan"}) - - if len(result) > 0 { - t.Error("Could not remove a recently appended item.") - } -} - -func TestSqUpdate(t *testing.T) { - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - col.Update(Where{"name": "José"}, Set{"name": "Joseph"}) - - result := col.Find(Where{"name": "Joseph"}) - - if len(result) == 0 { - t.Error("Could not update a recently appended item.") - } -} - -func TestSqPopulate(t *testing.T) { - var i int - - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - db.Use("test") - - places := []string{"Alaska", "Nebraska", "Alaska", "Acapulco", "Rome", "Singapore", "Alabama", "Cancún"} - - for i = 0; i < len(places); i++ { - db.Collection("places").Append(Item{ - "code_id": i, - "name": places[i], - }) - } - - people := db.Collection("people").FindAll( - Fields{"id", "name"}, - ) - - for i = 0; i < len(people); i++ { - person := people[i] - - // Has 5 children. - for j := 0; j < 5; j++ { - db.Collection("children").Append(Item{ - "name": fmt.Sprintf("%s's child %d", person["name"], j+1), - "parent_id": person["id"], - }) - } - - // Lives in - db.Collection("people").Update( - Where{"id": person["id"]}, - Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, - ) - - // Has visited - for k := 0; k < 3; k++ { - place := db.Collection("places").Find(Where{ - "code_id": int(rand.Float32() * float32(len(places))), - }) - db.Collection("visits").Append(Item{ - "place_id": place["id"], - "person_id": person["id"], - }) - } - } - -} - -func TestSqRelation(t *testing.T) { - db := SqliteSession(DataSource{Database: sqDatabase}) - - err := db.Open() - defer db.Close() - - if err != nil { - panic(err) - } - - col := db.Collection("people") - - result := col.FindAll( - Relate{ - "lives_in": On{ - db.Collection("places"), - Where{"code_id": "{place_code_id}"}, - }, - }, - RelateAll{ - "has_children": On{ - db.Collection("children"), - Where{"parent_id": "{id}"}, - }, - "has_visited": On{ - db.Collection("visits"), - Where{"person_id": "{id}"}, - Relate{ - "place": On{ - db.Collection("places"), - Where{"id": "{place_id}"}, - }, - }, - }, - }, - ) - - fmt.Printf("%# v\n", pretty.Formatter(result)) -}