From 4d451a6fbc31556c7a47cf2f0ef54cac1d6c1b77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <xiam@menteslibres.org> Date: Sun, 8 Jul 2012 22:54:34 -0500 Subject: [PATCH] Removing RemoveAll() and UpdateAll() since they are not consistent among databases, adding a Driver() method that returns a raw driver pointer, starting to fix documentation... --- db/db.go | 16 ++--- db/mongo.go | 161 +++++++++++++++--------------------------- db/mongo_test.go | 47 +++++++----- db/mysql.go | 161 ++++++++++++++++++------------------------ db/mysql_test.go | 54 ++++++++++---- db/postgresql.go | 135 ++++++++++++----------------------- db/postgresql_test.go | 37 ++++++---- db/sqlite.go | 129 ++++++++++++--------------------- db/sqlite_test.go | 43 ++++++----- 9 files changed, 337 insertions(+), 446 deletions(-) diff --git a/db/db.go b/db/db.go index 2aac2846..0dc5ca5d 100644 --- a/db/db.go +++ b/db/db.go @@ -165,7 +165,10 @@ type DataSource struct { // Database methods. type Database interface { - Connect() error + Driver() interface{} + + Open() error + Close() error Collection(string) Collection Collections() []string @@ -184,19 +187,16 @@ type Collection interface { FindAll(...interface{}) []Item Update(...interface{}) bool - UpdateAll(...interface{}) bool Remove(...interface{}) bool - RemoveAll(...interface{}) bool Truncate() bool } -// Specifies single or multiple requests in FindAll() expressions. -type multiFlag bool - // Specifies which fields to return in a query. type Fields []string -type Values []string -type Args []string +// Specifies single or multiple requests in FindAll() expressions. +type multiFlag bool +type sqlValues []string +type sqlArgs []string diff --git a/db/mongo.go b/db/mongo.go index 0e6b52e8..53db7770 100644 --- a/db/mongo.go +++ b/db/mongo.go @@ -35,21 +35,21 @@ import ( "time" ) -// MongoDB session. -type MongoDB struct { - config *DataSource +// MongoDataSource session. +type MongoDataSource struct { + config DataSource session *mgo.Session database *mgo.Database } -// MongoDB collection. -type MongoDBCollection struct { - parent *MongoDB +// MongoDataSource collection. +type MongoDataSourceCollection struct { + parent *MongoDataSource collection *mgo.Collection } // Converts Where keytypes into something that mgo can understand. -func (c *MongoDBCollection) marshal(where Where) map[string]interface{} { +func (c *MongoDataSourceCollection) marshal(where Where) map[string]interface{} { conds := make(map[string]interface{}) for key, val := range where { @@ -67,8 +67,8 @@ func (c *MongoDBCollection) marshal(where Where) map[string]interface{} { return conds } -// Deletes all rows in a collection. In MongoDB, deletes the whole collection. -func (c *MongoDBCollection) Truncate() bool { +// Deletes all rows in a collection. In MongoDataSource, deletes the whole collection. +func (c *MongoDataSourceCollection) Truncate() bool { err := c.collection.DropCollection() if err == nil { @@ -83,7 +83,7 @@ func (c *MongoDBCollection) Truncate() bool { // Example: // // collection.Append(Item { "name": "Peter" }) -func (c *MongoDBCollection) Append(items ...interface{}) bool { +func (c *MongoDataSourceCollection) Append(items ...interface{}) bool { parent := reflect.TypeOf(c.collection) method, _ := parent.MethodByName("Insert") @@ -106,7 +106,7 @@ func (c *MongoDBCollection) Append(items ...interface{}) bool { } // Compiles terms into conditions that mgo can understand. -func (c *MongoDBCollection) compileConditions(term interface{}) interface{} { +func (c *MongoDataSourceCollection) compileConditions(term interface{}) interface{} { switch term.(type) { case []interface{}: { @@ -151,7 +151,7 @@ func (c *MongoDBCollection) compileConditions(term interface{}) interface{} { } // Compiles terms into a query that mgo can understand. -func (c *MongoDBCollection) compileQuery(terms []interface{}) interface{} { +func (c *MongoDataSourceCollection) compileQuery(terms []interface{}) interface{} { var query interface{} compiled := c.compileConditions(terms) @@ -170,16 +170,6 @@ func (c *MongoDBCollection) compileQuery(terms []interface{}) interface{} { return query } -// Removes all the items that match the condition. See Remove(). -func (c *MongoDBCollection) RemoveAll(terms ...interface{}) bool { - - terms = append(terms, multiFlag(true)) - - result := c.invoke("Remove", terms) - - return result[0].Bool() -} - // Removes the first item that matches the provided conditions. // // Example: @@ -188,44 +178,15 @@ func (c *MongoDBCollection) RemoveAll(terms ...interface{}) bool { // Where { "name": "Peter" }, // Where { "last_name": "Parker" }, // ) -func (c *MongoDBCollection) Remove(terms ...interface{}) bool { - - var multi interface{} +func (c *MongoDataSourceCollection) Remove(terms ...interface{}) bool { query := c.compileQuery(terms) - itop := len(terms) - - for i := 0; i < itop; i++ { - term := terms[i] - - switch term.(type) { - case multiFlag: - { - multi = term.(multiFlag) - } - } - } - - if multi != nil { - c.collection.RemoveAll(query) - } else { - c.collection.Remove(query) - } + c.collection.RemoveAll(query) return true } -// Updates all the items that match the conditions. See Update(). -func (c *MongoDBCollection) UpdateAll(terms ...interface{}) bool { - - terms = append(terms, multiFlag(true)) - - result := c.invoke("Update", terms) - - return result[0].Bool() -} - // Updates a single document matching the provided conditions. You can specify the modification type by using Set, Modify or Upsert. // // Example of assigning field values with Set: @@ -248,19 +209,15 @@ func (c *MongoDBCollection) UpdateAll(terms ...interface{}) bool { // Where { "name": "Roberto" }, // Upsert { "name": "Robert"}, // ) -func (c *MongoDBCollection) Update(terms ...interface{}) bool { +func (c *MongoDataSourceCollection) Update(terms ...interface{}) bool { var set interface{} var upsert interface{} var modify interface{} - var multi interface{} set = nil upsert = nil modify = nil - multi = nil - - // TODO: make use multiFlag query := c.compileQuery(terms) @@ -282,37 +239,17 @@ func (c *MongoDBCollection) Update(terms ...interface{}) bool { { modify = term.(Modify) } - case multiFlag: - { - multi = term.(multiFlag) - } } } - if multi != nil { - - if set != nil { - c.collection.UpdateAll(query, Item{"$set": set}) - return true - } - - if modify != nil { - c.collection.UpdateAll(query, modify) - return true - } - - } else { - - if set != nil { - c.collection.Update(query, Item{"$set": set}) - return true - } - - if modify != nil { - c.collection.Update(query, modify) - return true - } + if set != nil { + c.collection.UpdateAll(query, Item{"$set": set}) + return true + } + if modify != nil { + c.collection.UpdateAll(query, modify) + return true } if upsert != nil { @@ -323,8 +260,8 @@ func (c *MongoDBCollection) Update(terms ...interface{}) bool { return false } -// Calls a MongoDBCollection function by string. -func (c *MongoDBCollection) invoke(fn string, terms []interface{}) []reflect.Value { +// Calls a MongoDataSourceCollection function by string. +func (c *MongoDataSourceCollection) invoke(fn string, terms []interface{}) []reflect.Value { self := reflect.TypeOf(c) method, _ := self.MethodByName(fn) @@ -344,7 +281,7 @@ func (c *MongoDBCollection) invoke(fn string, terms []interface{}) []reflect.Val } // Returns the number of total items matching the provided conditions. -func (c *MongoDBCollection) Count(terms ...interface{}) int { +func (c *MongoDataSourceCollection) Count(terms ...interface{}) int { q := c.invoke("BuildQuery", terms) p := q[0].Interface().(*mgo.Query) @@ -372,7 +309,7 @@ func (c *MongoDBCollection) Count(terms ...interface{}) int { // Where { "age": 20 }, // }, // ) -func (c *MongoDBCollection) Find(terms ...interface{}) Item { +func (c *MongoDataSourceCollection) Find(terms ...interface{}) Item { var item Item @@ -393,7 +330,7 @@ func (c *MongoDBCollection) Find(terms ...interface{}) Item { // Returns a mgo.Query that matches the provided terms. // // This is actually a function that is only public because of the implementation of mongo.go but you should not use or rely on it. -func (c *MongoDBCollection) BuildQuery(terms ...interface{}) *mgo.Query { +func (c *MongoDataSourceCollection) BuildQuery(terms ...interface{}) *mgo.Query { var sort interface{} @@ -455,7 +392,7 @@ func (c *MongoDBCollection) BuildQuery(terms ...interface{}) *mgo.Query { // Where { "last_name": "Smith" }, // Limit(10), // ) -func (c *MongoDBCollection) FindAll(terms ...interface{}) []Item { +func (c *MongoDataSourceCollection) FindAll(terms ...interface{}) []Item { var items []Item var result []interface{} @@ -583,10 +520,10 @@ func (c *MongoDBCollection) FindAll(terms ...interface{}) []Item { // Executing external query. if relation["all"] == true { - value := relation["collection"].(*MongoDBCollection).invoke("FindAll", terms) + value := relation["collection"].(*MongoDataSourceCollection).invoke("FindAll", terms) item[relation["name"].(string)] = value[0].Interface().([]Item) } else { - value := relation["collection"].(*MongoDBCollection).invoke("Find", terms) + value := relation["collection"].(*MongoDataSourceCollection).invoke("Find", terms) item[relation["name"].(string)] = value[0].Interface().(Item) } @@ -599,12 +536,12 @@ func (c *MongoDBCollection) FindAll(terms ...interface{}) []Item { return items } -// Returns a new MongoDB object, this object can be then used to Connect() to the database and operate on Collections. +// Returns a new MongoDataSource object, this object can be then used to Connect() to the database and operate on Collections. // See db.DataSource{}. // // Example: // -// source := NewMongoDB(&DataSource { +// source := MongoSession(&DataSource { // Host: "localhost", // Database: "test", // User: "charly", @@ -622,29 +559,33 @@ func (c *MongoDBCollection) FindAll(terms ...interface{}) []Item { // people := db.Collection("people") // // result := people.Find(Where { "name": "José" }) -func NewMongoDB(config *DataSource) Database { - m := &MongoDB{} +func MongoSession(config DataSource) Database { + m := &MongoDataSource{} m.config = config return m } -// Switches the current session database to the provided name. See NewMongoDB(). -func (m *MongoDB) Use(database string) error { +// Switches the current session database to the provided name. See MongoSession(). +func (m *MongoDataSource) Use(database string) error { m.config.Database = database m.database = m.session.DB(m.config.Database) return nil } -// Returns a Collection from the currently active database given the name. See NewMongoDB(). -func (m *MongoDB) Collection(name string) Collection { - c := &MongoDBCollection{} +// Returns a Collection from the currently active database given the name. See MongoSession(). +func (m *MongoDataSource) Collection(name string) Collection { + c := &MongoDataSourceCollection{} c.parent = m c.collection = m.database.C(name) return c } -// Connects to the previously specified datasource. See NewMongoDB(). -func (m *MongoDB) Connect() error { +func (m *MongoDataSource) Driver() interface{} { + return m.session +} + +// Connects to the previously specified datasource. See MongoSession(). +func (m *MongoDataSource) Open() error { var err error connURL := &url.URL{Scheme: "mongodb"} @@ -677,13 +618,21 @@ func (m *MongoDB) Connect() error { } // Entirely drops the active database. -func (m *MongoDB) Drop() error { +func (m *MongoDataSource) Drop() error { err := m.database.DropDatabase() return err } +// Entirely drops the active database. +func (m *MongoDataSource) Close() error { + if m.session != nil { + m.session.Close() + } + return nil +} + // Returns all the collection names on the active database. -func (m *MongoDB) Collections() []string { +func (m *MongoDataSource) Collections() []string { names, _ := m.database.CollectionNames() return names } diff --git a/db/mongo_test.go b/db/mongo_test.go index 5981d438..601d15fa 100644 --- a/db/mongo_test.go +++ b/db/mongo_test.go @@ -10,11 +10,12 @@ import ( const mgHost = "10.0.0.11" const mgDatabase = "gotest" -func TestMgConnect(t *testing.T) { +func TestMgOpen(t *testing.T) { - db := NewMongoDB(&DataSource{Host: "0.0.0.0"}) + db := MongoSession(DataSource{Host: "0.0.0.0"}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { t.Logf("Got %t, this was intended.", err) @@ -26,9 +27,10 @@ func TestMgConnect(t *testing.T) { func TestMgAuthFail(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase, User: "unknown", Password: "fail"}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase, User: "unknown", Password: "fail"}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { t.Logf("Got %t, this was intended.", err) @@ -40,9 +42,10 @@ func TestMgAuthFail(t *testing.T) { func TestMgDrop(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -53,9 +56,10 @@ func TestMgDrop(t *testing.T) { func TestMgAppend(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -77,9 +81,10 @@ func TestMgAppend(t *testing.T) { func TestMgFind(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -96,9 +101,10 @@ func TestMgFind(t *testing.T) { } func TestMgDelete(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -116,9 +122,10 @@ func TestMgDelete(t *testing.T) { } func TestMgUpdate(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -138,9 +145,10 @@ func TestMgUpdate(t *testing.T) { func TestMgPopulate(t *testing.T) { var i int - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -189,9 +197,10 @@ func TestMgPopulate(t *testing.T) { } func TestMgRelation(t *testing.T) { - db := NewMongoDB(&DataSource{Host: mgHost, Database: mgDatabase}) + db := MongoSession(DataSource{Host: mgHost, Database: mgDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) diff --git a/db/mysql.go b/db/mysql.go index 4943c900..54b5b55f 100644 --- a/db/mysql.go +++ b/db/mysql.go @@ -34,14 +34,9 @@ import ( "strings" ) -/* -type Values []string -type Args []string -*/ - type myQuery struct { - Query []string - Args []string + Query []string + sqlArgs []string } func myCompile(terms []interface{}) *myQuery { @@ -55,18 +50,18 @@ func myCompile(terms []interface{}) *myQuery { { q.Query = append(q.Query, term.(string)) } - case Args: + case sqlArgs: { - for _, arg := range term.(Args) { - q.Args = append(q.Args, arg) + for _, arg := range term.(sqlArgs) { + q.sqlArgs = append(q.sqlArgs, arg) } } - case Values: + case sqlValues: { - args := make([]string, len(term.(Values))) - for i, arg := range term.(Values) { + args := make([]string, len(term.(sqlValues))) + for i, arg := range term.(sqlValues) { args[i] = "?" - q.Args = append(q.Args, arg) + q.sqlArgs = append(q.sqlArgs, arg) } q.Query = append(q.Query, "("+strings.Join(args, ", ")+")") } @@ -84,20 +79,22 @@ func myFields(names []string) string { return "(" + strings.Join(names, ", ") + ")" } -func myValues(values []string) Values { - ret := make(Values, len(values)) +func myValues(values []string) sqlValues { + ret := make(sqlValues, len(values)) for i, _ := range values { ret[i] = values[i] } return ret } -type MysqlDB struct { - config *DataSource +// Stores driver's session data. +type MysqlDataSource struct { session *sql.DB + config DataSource collections map[string]Collection } +// Returns all items from a query. func (t *MysqlTable) myFetchAll(rows sql.Rows) []Item { items := []Item{} @@ -161,7 +158,8 @@ func (t *MysqlTable) myFetchAll(rows sql.Rows) []Item { return items } -func (my *MysqlDB) myExec(method string, terms ...interface{}) sql.Rows { +// Executes a database/sql method. +func (my *MysqlDataSource) myExec(method string, terms ...interface{}) sql.Rows { sn := reflect.ValueOf(my.session) fn := sn.MethodByName(method) @@ -170,15 +168,15 @@ func (my *MysqlDB) myExec(method string, terms ...interface{}) sql.Rows { /* fmt.Printf("Q: %v\n", q.Query) - fmt.Printf("A: %v\n", q.Args) + fmt.Printf("A: %v\n", q.sqlArgs) */ - args := make([]reflect.Value, len(q.Args)+1) + args := make([]reflect.Value, len(q.sqlArgs)+1) args[0] = reflect.ValueOf(strings.Join(q.Query, " ")) - for i := 0; i < len(q.Args); i++ { - args[1+i] = reflect.ValueOf(q.Args[i]) + for i := 0; i < len(q.sqlArgs); i++ { + args[1+i] = reflect.ValueOf(q.sqlArgs[i]) } res := fn.Call(args) @@ -190,20 +188,31 @@ func (my *MysqlDB) myExec(method string, terms ...interface{}) sql.Rows { return res[0].Elem().Interface().(sql.Rows) } +// Represents a MySQL table. type MysqlTable struct { - parent *MysqlDB + parent *MysqlDataSource name string types map[string]reflect.Kind } -func NewMysqlDB(config *DataSource) Database { - m := &MysqlDB{} - m.config = config - m.collections = make(map[string]Collection) - return m +// Configures and returns a MySQL database session. +func MysqlSession(config DataSource) Database { + my := &MysqlDataSource{} + my.config = config + my.collections = make(map[string]Collection) + return my +} + +// Closes a previously opened MySQL database session. +func (my *MysqlDataSource) Close() error { + if my.session != nil { + return my.session.Close() + } + return nil } -func (my *MysqlDB) Connect() error { +// Tries to open a connection to the current MySQL session. +func (my *MysqlDataSource) Open() error { var err error if my.config.Host == "" { @@ -229,18 +238,26 @@ func (my *MysqlDB) Connect() error { return nil } -func (my *MysqlDB) Use(database string) error { +// Changes the active database. +func (my *MysqlDataSource) Use(database string) error { my.config.Database = database my.session.Query(fmt.Sprintf("USE %s", database)) return nil } -func (my *MysqlDB) Drop() error { +// Drops the current active database. +func (my *MysqlDataSource) Drop() error { my.session.Query(fmt.Sprintf("DROP DATABASE %s", my.config.Database)) return nil } -func (my *MysqlDB) Collections() []string { +// Returns a *sql.DB object that represents an internal session. +func (my *MysqlDataSource) Driver() interface{} { + return my.session +} + +// Returns the list of MySQL tables in the current database. +func (my *MysqlDataSource) Collections() []string { var collections []string var collection string rows, _ := my.session.Query("SHOW TABLES") @@ -253,6 +270,7 @@ func (my *MysqlDB) Collections() []string { return collections } +// Calls an internal function. func (t *MysqlTable) invoke(fn string, terms []interface{}) []reflect.Value { self := reflect.ValueOf(t) @@ -270,9 +288,10 @@ func (t *MysqlTable) invoke(fn string, terms []interface{}) []reflect.Value { return exec } -func (t *MysqlTable) compileSet(term Set) (string, Args) { +// A helper for preparing queries that use SET. +func (t *MysqlTable) compileSet(term Set) (string, sqlArgs) { sql := []string{} - args := Args{} + args := sqlArgs{} for key, arg := range term { sql = append(sql, fmt.Sprintf("%s = ?", key)) @@ -282,9 +301,10 @@ func (t *MysqlTable) compileSet(term Set) (string, Args) { return strings.Join(sql, ", "), args } -func (t *MysqlTable) compileConditions(term interface{}) (string, Args) { +// A helper for preparing queries that have conditions. +func (t *MysqlTable) compileConditions(term interface{}) (string, sqlArgs) { sql := []string{} - args := Args{} + args := sqlArgs{} switch term.(type) { case []interface{}: @@ -354,6 +374,7 @@ func (t *MysqlTable) compileConditions(term interface{}) (string, Args) { return "", args } +// Converts Where{} structures into SQL. func (t *MysqlTable) marshal(where Where) (string, []string) { for key, val := range where { @@ -373,6 +394,7 @@ func (t *MysqlTable) marshal(where Where) (string, []string) { return "", []string{} } +// Deletes all the rows in the table. func (t *MysqlTable) Truncate() bool { t.parent.myExec( @@ -383,49 +405,11 @@ func (t *MysqlTable) Truncate() bool { return false } +// Removes all the rows in the table that match certain conditions. func (t *MysqlTable) Remove(terms ...interface{}) bool { - terms = append(terms, Limit(1)) - - result := t.invoke("RemoveAll", terms) - - if len(result) > 0 { - return result[0].Interface().(bool) - } - - return false -} - -func (t *MysqlTable) Update(terms ...interface{}) bool { - terms = append(terms, Limit(1)) - - result := t.invoke("UpdateAll", terms) - - if len(result) > 0 { - return result[0].Interface().(bool) - } - - return false -} - -func (t *MysqlTable) RemoveAll(terms ...interface{}) bool { - limit := "" - offset := "" conditions, cargs := t.compileConditions(terms) - for _, term := range terms { - switch term.(type) { - case Limit: - { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) - } - case Offset: - { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) - } - } - } - if conditions == "" { conditions = "1 = 1" } @@ -434,18 +418,15 @@ func (t *MysqlTable) RemoveAll(terms ...interface{}) bool { "Query", fmt.Sprintf("DELETE FROM %s", myTable(t.name)), fmt.Sprintf("WHERE %s", conditions), cargs, - limit, offset, ) return true } -func (t *MysqlTable) UpdateAll(terms ...interface{}) bool { +// Modifies all the rows in the table that match certain conditions. +func (t *MysqlTable) Update(terms ...interface{}) bool { var fields string - var fargs Args - - limit := "" - offset := "" + var fargs sqlArgs conditions, cargs := t.compileConditions(terms) @@ -455,14 +436,6 @@ func (t *MysqlTable) UpdateAll(terms ...interface{}) bool { { fields, fargs = t.compileSet(term.(Set)) } - case Limit: - { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) - } - case Offset: - { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) - } } } @@ -474,12 +447,12 @@ func (t *MysqlTable) UpdateAll(terms ...interface{}) bool { "Query", fmt.Sprintf("UPDATE %s SET %s", myTable(t.name), fields), fargs, fmt.Sprintf("WHERE %s", conditions), cargs, - limit, offset, ) return true } +// Returns all the rows in the table that match certain conditions. func (t *MysqlTable) FindAll(terms ...interface{}) []Item { var itop int @@ -656,6 +629,7 @@ func (t *MysqlTable) FindAll(terms ...interface{}) []Item { return items } +// Returns the number of rows in the current table that match certain conditions. func (t *MysqlTable) Count(terms ...interface{}) int { terms = append(terms, Fields{"COUNT(1) AS _total"}) @@ -673,6 +647,7 @@ func (t *MysqlTable) Count(terms ...interface{}) int { return 0 } +// Returns the first row in the table that matches certain conditions. func (t *MysqlTable) Find(terms ...interface{}) Item { var item Item @@ -691,6 +666,7 @@ func (t *MysqlTable) Find(terms ...interface{}) Item { return item } +// Inserts a row into the table. func (t *MysqlTable) Append(items ...interface{}) bool { itop := len(items) @@ -720,7 +696,8 @@ func (t *MysqlTable) Append(items ...interface{}) bool { return true } -func (my *MysqlDB) Collection(name string) Collection { +// Returns a MySQL table object by name. +func (my *MysqlDataSource) Collection(name string) Collection { if collection, ok := my.collections[name]; ok == true { return collection diff --git a/db/mysql_test.go b/db/mysql_test.go index a7796e75..e60adc4a 100644 --- a/db/mysql_test.go +++ b/db/mysql_test.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "fmt" "github.com/kr/pretty" "math/rand" @@ -14,9 +15,10 @@ const myPassword = "gopass" func TestMyTruncate(t *testing.T) { - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -36,9 +38,10 @@ func TestMyTruncate(t *testing.T) { func TestMyAppend(t *testing.T) { - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -62,9 +65,10 @@ func TestMyAppend(t *testing.T) { func TestMyFind(t *testing.T) { - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -81,9 +85,10 @@ func TestMyFind(t *testing.T) { } func TestMyDelete(t *testing.T) { - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -101,9 +106,10 @@ func TestMyDelete(t *testing.T) { } func TestMyUpdate(t *testing.T) { - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -125,9 +131,10 @@ func TestMyUpdate(t *testing.T) { func TestMyPopulate(t *testing.T) { var i int - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -180,9 +187,10 @@ func TestMyPopulate(t *testing.T) { } func TestMyRelation(t *testing.T) { - db := NewMysqlDB(&DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -217,3 +225,21 @@ func TestMyRelation(t *testing.T) { fmt.Printf("%# v\n", pretty.Formatter(result)) } + +func TestCustom(t *testing.T) { + db := MysqlSession(DataSource{Host: myHost, Database: myDatabase, User: myUser, Password: myPassword}) + + err := db.Open() + defer db.Close() + + if err != nil { + panic(err) + } + + _, err = db.Driver().(*sql.DB).Query("SELECT NOW()") + + if err != nil { + panic(err) + } + +} diff --git a/db/postgresql.go b/db/postgresql.go index 16e46535..3fa548f6 100644 --- a/db/postgresql.go +++ b/db/postgresql.go @@ -35,8 +35,8 @@ import ( ) type pgQuery struct { - Query []string - Args []string + Query []string + sqlArgs []string } func pgCompile(terms []interface{}) *pgQuery { @@ -50,18 +50,18 @@ func pgCompile(terms []interface{}) *pgQuery { { q.Query = append(q.Query, term.(string)) } - case Args: + case sqlArgs: { - for _, arg := range term.(Args) { - q.Args = append(q.Args, arg) + for _, arg := range term.(sqlArgs) { + q.sqlArgs = append(q.sqlArgs, arg) } } - case Values: + case sqlValues: { - args := make([]string, len(term.(Values))) - for i, arg := range term.(Values) { + args := make([]string, len(term.(sqlValues))) + for i, arg := range term.(sqlValues) { args[i] = "?" - q.Args = append(q.Args, arg) + q.sqlArgs = append(q.sqlArgs, arg) } q.Query = append(q.Query, "("+strings.Join(args, ", ")+")") } @@ -79,16 +79,16 @@ func pgFields(names []string) string { return "(" + strings.Join(names, ", ") + ")" } -func pgValues(values []string) Values { - ret := make(Values, len(values)) +func pgValues(values []string) sqlValues { + ret := make(sqlValues, len(values)) for i, _ := range values { ret[i] = values[i] } return ret } -type PostgresqlDB struct { - config *DataSource +type PostgresqlDataSource struct { + config DataSource session *sql.DB collections map[string]Collection } @@ -156,7 +156,7 @@ func (t *PostgresqlTable) pgFetchAll(rows sql.Rows) []Item { return items } -func (pg *PostgresqlDB) pgExec(method string, terms ...interface{}) sql.Rows { +func (pg *PostgresqlDataSource) pgExec(method string, terms ...interface{}) sql.Rows { sn := reflect.ValueOf(pg.session) fn := sn.MethodByName(method) @@ -164,15 +164,15 @@ func (pg *PostgresqlDB) pgExec(method string, terms ...interface{}) sql.Rows { q := pgCompile(terms) //fmt.Printf("Q: %v\n", q.Query) - //fmt.Printf("A: %v\n", q.Args) + //fmt.Printf("A: %v\n", q.sqlArgs) qs := strings.Join(q.Query, " ") - args := make([]reflect.Value, len(q.Args)+1) + args := make([]reflect.Value, len(q.sqlArgs)+1) - for i := 0; i < len(q.Args); i++ { + for i := 0; i < len(q.sqlArgs); i++ { qs = strings.Replace(qs, "?", fmt.Sprintf("$%d", i+1), 1) - args[1+i] = reflect.ValueOf(q.Args[i]) + args[1+i] = reflect.ValueOf(q.sqlArgs[i]) } args[0] = reflect.ValueOf(qs) @@ -187,19 +187,27 @@ func (pg *PostgresqlDB) pgExec(method string, terms ...interface{}) sql.Rows { } type PostgresqlTable struct { - parent *PostgresqlDB + parent *PostgresqlDataSource name string types map[string]reflect.Kind } -func NewPostgresqlDB(config *DataSource) Database { - m := &PostgresqlDB{} +func PostgresqlSession(config DataSource) Database { + m := &PostgresqlDataSource{} m.config = config m.collections = make(map[string]Collection) return m } -func (pg *PostgresqlDB) Connect() error { +// Closes a previously opened MySQL database session. +func (pg *PostgresqlDataSource) Close() error { + if pg.session != nil { + return pg.session.Close() + } + return nil +} + +func (pg *PostgresqlDataSource) Open() error { var err error if pg.config.Host == "" { @@ -225,17 +233,22 @@ func (pg *PostgresqlDB) Connect() error { return nil } -func (pg *PostgresqlDB) Use(database string) error { +func (pg *PostgresqlDataSource) Use(database string) error { pg.config.Database = database - return pg.Connect() + return pg.Open() } -func (pg *PostgresqlDB) Drop() error { +func (pg *PostgresqlDataSource) Drop() error { pg.session.Query(fmt.Sprintf("DROP DATABASE %s", pg.config.Database)) return nil } -func (pg *PostgresqlDB) Collections() []string { +// Returns a *sql.DB object that represents an internal session. +func (pg *PostgresqlDataSource) Driver() interface{} { + return pg.session +} + +func (pg *PostgresqlDataSource) Collections() []string { var collections []string var collection string rows, _ := pg.session.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'") @@ -265,9 +278,9 @@ func (t *PostgresqlTable) invoke(fn string, terms []interface{}) []reflect.Value return exec } -func (t *PostgresqlTable) compileSet(term Set) (string, Args) { +func (t *PostgresqlTable) compileSet(term Set) (string, sqlArgs) { sql := []string{} - args := Args{} + args := sqlArgs{} for key, arg := range term { sql = append(sql, fmt.Sprintf("%s = ?", key)) @@ -277,9 +290,9 @@ func (t *PostgresqlTable) compileSet(term Set) (string, Args) { return strings.Join(sql, ", "), args } -func (t *PostgresqlTable) compileConditions(term interface{}) (string, Args) { +func (t *PostgresqlTable) compileConditions(term interface{}) (string, sqlArgs) { sql := []string{} - args := Args{} + args := sqlArgs{} switch term.(type) { case []interface{}: @@ -379,50 +392,9 @@ func (t *PostgresqlTable) Truncate() bool { } func (t *PostgresqlTable) Remove(terms ...interface{}) bool { - // Does not support LIMIT - //terms = append(terms, Limit(1)) - - result := t.invoke("RemoveAll", terms) - - if len(result) > 0 { - return result[0].Interface().(bool) - } - - return false -} - -func (t *PostgresqlTable) Update(terms ...interface{}) bool { - // Does not support LIMIT - // terms = append(terms, Limit(1)) - - result := t.invoke("UpdateAll", terms) - - if len(result) > 0 { - return result[0].Interface().(bool) - } - - return false -} - -func (t *PostgresqlTable) RemoveAll(terms ...interface{}) bool { - limit := "" - offset := "" conditions, cargs := t.compileConditions(terms) - for _, term := range terms { - switch term.(type) { - case Limit: - { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) - } - case Offset: - { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) - } - } - } - if conditions == "" { conditions = "1 = 1" } @@ -431,18 +403,14 @@ func (t *PostgresqlTable) RemoveAll(terms ...interface{}) bool { "Query", fmt.Sprintf("DELETE FROM %s", pgTable(t.name)), fmt.Sprintf("WHERE %s", conditions), cargs, - limit, offset, ) return true } -func (t *PostgresqlTable) UpdateAll(terms ...interface{}) bool { +func (t *PostgresqlTable) Update(terms ...interface{}) bool { var fields string - var fargs Args - - limit := "" - offset := "" + var fargs sqlArgs conditions, cargs := t.compileConditions(terms) @@ -452,14 +420,6 @@ func (t *PostgresqlTable) UpdateAll(terms ...interface{}) bool { { fields, fargs = t.compileSet(term.(Set)) } - case Limit: - { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) - } - case Offset: - { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) - } } } @@ -471,7 +431,6 @@ func (t *PostgresqlTable) UpdateAll(terms ...interface{}) bool { "Query", fmt.Sprintf("UPDATE %s SET %s", pgTable(t.name), fields), fargs, fmt.Sprintf("WHERE %s", conditions), cargs, - limit, offset, ) return true @@ -717,7 +676,7 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool { return true } -func (pg *PostgresqlDB) Collection(name string) Collection { +func (pg *PostgresqlDataSource) Collection(name string) Collection { if collection, ok := pg.collections[name]; ok == true { return collection @@ -732,7 +691,7 @@ func (pg *PostgresqlDB) Collection(name string) Collection { rows := t.parent.pgExec( "Query", - "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", Args{t.name}, + "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", sqlArgs{t.name}, ) columns := t.pgFetchAll(rows) diff --git a/db/postgresql_test.go b/db/postgresql_test.go index 633c8c24..7c5f36fd 100644 --- a/db/postgresql_test.go +++ b/db/postgresql_test.go @@ -14,9 +14,10 @@ const pgPassword = "gopass" func TestPgTruncate(t *testing.T) { - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -36,9 +37,10 @@ func TestPgTruncate(t *testing.T) { func TestPgAppend(t *testing.T) { - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -62,9 +64,10 @@ func TestPgAppend(t *testing.T) { func TestPgFind(t *testing.T) { - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -81,9 +84,10 @@ func TestPgFind(t *testing.T) { } func TestPgDelete(t *testing.T) { - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -91,7 +95,7 @@ func TestPgDelete(t *testing.T) { col := db.Collection("people") - col.RemoveAll(Where{"name": "Juan"}) + col.Remove(Where{"name": "Juan"}) result := col.Find(Where{"name": "Juan"}) @@ -101,9 +105,10 @@ func TestPgDelete(t *testing.T) { } func TestPgUpdate(t *testing.T) { - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -125,9 +130,10 @@ func TestPgUpdate(t *testing.T) { func TestPgPopulate(t *testing.T) { var i int - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -180,9 +186,10 @@ func TestPgPopulate(t *testing.T) { } func TestPgRelation(t *testing.T) { - db := NewPostgresqlDB(&DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) + db := PostgresqlSession(DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) diff --git a/db/sqlite.go b/db/sqlite.go index 6b6cd0f3..a8f8283e 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -35,8 +35,8 @@ import ( ) type slQuery struct { - Query []string - Args []string + Query []string + sqlArgs []string } func slCompile(terms []interface{}) *slQuery { @@ -50,18 +50,18 @@ func slCompile(terms []interface{}) *slQuery { { q.Query = append(q.Query, term.(string)) } - case Args: + case sqlArgs: { - for _, arg := range term.(Args) { - q.Args = append(q.Args, arg) + for _, arg := range term.(sqlArgs) { + q.sqlArgs = append(q.sqlArgs, arg) } } - case Values: + case sqlValues: { - args := make([]string, len(term.(Values))) - for i, arg := range term.(Values) { + args := make([]string, len(term.(sqlValues))) + for i, arg := range term.(sqlValues) { args[i] = "?" - q.Args = append(q.Args, arg) + q.sqlArgs = append(q.sqlArgs, arg) } q.Query = append(q.Query, "("+strings.Join(args, ", ")+")") } @@ -79,16 +79,16 @@ func slFields(names []string) string { return "(" + strings.Join(names, ", ") + ")" } -func slValues(values []string) Values { - ret := make(Values, len(values)) +func slValues(values []string) sqlValues { + ret := make(sqlValues, len(values)) for i, _ := range values { ret[i] = values[i] } return ret } -type SqliteDB struct { - config *DataSource +type SqliteDataSource struct { + config DataSource session *sql.DB collections map[string]Collection } @@ -156,7 +156,7 @@ func (t *SqliteTable) slFetchAll(rows sql.Rows) []Item { return items } -func (sl *SqliteDB) slExec(method string, terms ...interface{}) sql.Rows { +func (sl *SqliteDataSource) slExec(method string, terms ...interface{}) sql.Rows { var rows sql.Rows @@ -167,15 +167,15 @@ func (sl *SqliteDB) slExec(method string, terms ...interface{}) sql.Rows { /* fmt.Printf("Q: %v\n", q.Query) - fmt.Printf("A: %v\n", q.Args) + fmt.Printf("A: %v\n", q.sqlArgs) */ - args := make([]reflect.Value, len(q.Args)+1) + args := make([]reflect.Value, len(q.sqlArgs)+1) args[0] = reflect.ValueOf(strings.Join(q.Query, " ")) - for i := 0; i < len(q.Args); i++ { - args[1+i] = reflect.ValueOf(q.Args[i]) + for i := 0; i < len(q.sqlArgs); i++ { + args[1+i] = reflect.ValueOf(q.sqlArgs[i]) } res := fn.Call(args) @@ -198,19 +198,24 @@ func (sl *SqliteDB) slExec(method string, terms ...interface{}) sql.Rows { } type SqliteTable struct { - parent *SqliteDB + parent *SqliteDataSource name string types map[string]reflect.Kind } -func NewSqliteDB(config *DataSource) Database { - m := &SqliteDB{} +func SqliteSession(config DataSource) Database { + m := &SqliteDataSource{} m.config = config m.collections = make(map[string]Collection) return m } -func (sl *SqliteDB) Connect() error { +// Returns a *sql.DB object that represents an internal session. +func (sl *SqliteDataSource) Driver() interface{} { + return sl.session +} + +func (sl *SqliteDataSource) Open() error { var err error if sl.config.Database == "" { @@ -228,18 +233,25 @@ func (sl *SqliteDB) Connect() error { return nil } -func (sl *SqliteDB) Use(database string) error { +func (sl *SqliteDataSource) Close() error { + if sl.session != nil { + return sl.session.Close() + } + return nil +} + +func (sl *SqliteDataSource) Use(database string) error { sl.config.Database = database sl.session.Query(fmt.Sprintf("USE %s", database)) return nil } -func (sl *SqliteDB) Drop() error { +func (sl *SqliteDataSource) Drop() error { sl.session.Query(fmt.Sprintf("DROP DATABASE %s", sl.config.Database)) return nil } -func (sl *SqliteDB) Collections() []string { +func (sl *SqliteDataSource) Collections() []string { var collections []string var collection string @@ -270,9 +282,9 @@ func (t *SqliteTable) invoke(fn string, terms []interface{}) []reflect.Value { return exec } -func (t *SqliteTable) compileSet(term Set) (string, Args) { +func (t *SqliteTable) compileSet(term Set) (string, sqlArgs) { sql := []string{} - args := Args{} + args := sqlArgs{} for key, arg := range term { sql = append(sql, fmt.Sprintf("%s = ?", key)) @@ -282,9 +294,9 @@ func (t *SqliteTable) compileSet(term Set) (string, Args) { return strings.Join(sql, ", "), args } -func (t *SqliteTable) compileConditions(term interface{}) (string, Args) { +func (t *SqliteTable) compileConditions(term interface{}) (string, sqlArgs) { sql := []string{} - args := Args{} + args := sqlArgs{} switch term.(type) { case []interface{}: @@ -382,50 +394,10 @@ func (t *SqliteTable) Truncate() bool { return false } - func (t *SqliteTable) Remove(terms ...interface{}) bool { - terms = append(terms, Limit(1)) - - result := t.invoke("RemoveAll", terms) - - if len(result) > 0 { - return result[0].Interface().(bool) - } - - return false -} - -func (t *SqliteTable) Update(terms ...interface{}) bool { - terms = append(terms, Limit(1)) - - result := t.invoke("UpdateAll", terms) - - if len(result) > 0 { - return result[0].Interface().(bool) - } - - return false -} - -func (t *SqliteTable) RemoveAll(terms ...interface{}) bool { - limit := "" - offset := "" conditions, cargs := t.compileConditions(terms) - for _, term := range terms { - switch term.(type) { - case Limit: - { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) - } - case Offset: - { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) - } - } - } - if conditions == "" { conditions = "1 = 1" } @@ -434,18 +406,14 @@ func (t *SqliteTable) RemoveAll(terms ...interface{}) bool { "Exec", fmt.Sprintf("DELETE FROM %s", slTable(t.name)), fmt.Sprintf("WHERE %s", conditions), cargs, - limit, offset, ) return true } -func (t *SqliteTable) UpdateAll(terms ...interface{}) bool { +func (t *SqliteTable) Update(terms ...interface{}) bool { var fields string - var fargs Args - - limit := "" - offset := "" + var fargs sqlArgs conditions, cargs := t.compileConditions(terms) @@ -455,14 +423,6 @@ func (t *SqliteTable) UpdateAll(terms ...interface{}) bool { { fields, fargs = t.compileSet(term.(Set)) } - case Limit: - { - limit = fmt.Sprintf("LIMIT %v", term.(Limit)) - } - case Offset: - { - offset = fmt.Sprintf("OFFSET %v", term.(Offset)) - } } } @@ -474,7 +434,6 @@ func (t *SqliteTable) UpdateAll(terms ...interface{}) bool { "Exec", fmt.Sprintf("UPDATE %s SET %s", slTable(t.name), fields), fargs, fmt.Sprintf("WHERE %s", conditions), cargs, - limit, offset, ) return true @@ -721,7 +680,7 @@ func (t *SqliteTable) Append(items ...interface{}) bool { return true } -func (sl *SqliteDB) Collection(name string) Collection { +func (sl *SqliteDataSource) Collection(name string) Collection { if collection, ok := sl.collections[name]; ok == true { return collection diff --git a/db/sqlite_test.go b/db/sqlite_test.go index 0686fd9a..480d94f7 100644 --- a/db/sqlite_test.go +++ b/db/sqlite_test.go @@ -11,9 +11,10 @@ const sqDatabase = "./dumps/gotest.sqlite3.db" func TestSqTruncate(t *testing.T) { - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -33,9 +34,10 @@ func TestSqTruncate(t *testing.T) { func TestSqAppend(t *testing.T) { - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -59,9 +61,10 @@ func TestSqAppend(t *testing.T) { func TestSqFind(t *testing.T) { - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -78,9 +81,10 @@ func TestSqFind(t *testing.T) { } func TestSqDelete(t *testing.T) { - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -88,8 +92,7 @@ func TestSqDelete(t *testing.T) { col := db.Collection("people") - // Remove() may not always work http://www.sqlite.org/compile.html#enable_update_delete_limit - col.RemoveAll(Where{"name": "Juan"}) + col.Remove(Where{"name": "Juan"}) result := col.Find(Where{"name": "Juan"}) @@ -99,9 +102,10 @@ func TestSqDelete(t *testing.T) { } func TestSqUpdate(t *testing.T) { - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -109,8 +113,7 @@ func TestSqUpdate(t *testing.T) { col := db.Collection("people") - // Update() may not always work http://www.sqlite.org/compile.html#enable_update_delete_limit - col.UpdateAll(Where{"name": "José"}, Set{"name": "Joseph"}) + col.Update(Where{"name": "José"}, Set{"name": "Joseph"}) result := col.Find(Where{"name": "Joseph"}) @@ -122,9 +125,10 @@ func TestSqUpdate(t *testing.T) { func TestSqPopulate(t *testing.T) { var i int - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) @@ -157,7 +161,7 @@ func TestSqPopulate(t *testing.T) { } // Lives in - db.Collection("people").UpdateAll( + db.Collection("people").Update( Where{"id": person["id"]}, Set{"place_code_id": int(rand.Float32() * float32(len(places)))}, ) @@ -177,9 +181,10 @@ func TestSqPopulate(t *testing.T) { } func TestSqRelation(t *testing.T) { - db := NewSqliteDB(&DataSource{Database: sqDatabase}) + db := SqliteSession(DataSource{Database: sqDatabase}) - err := db.Connect() + err := db.Open() + defer db.Close() if err != nil { panic(err) -- GitLab