diff --git a/.travis.yml b/.travis.yml index 1ee8077c710473f63e0b4cae15338e9f77ca4ffa..2653088a0af25d0f291cc01e8f7a0409ef651d47 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: -# - 1.1 // QL fails to compile on lower then go1.2. -# - 1.2 // "go get" SSL problems with go1.2. +# - 1.1 // Unsupported, QL fails to compile on go < 1.2. +# - 1.2 // Unsupported on travis because it fails to fetch some resources via HTTPs. # - 1.2.1 # - 1.2.2 - 1.3 @@ -10,35 +10,54 @@ go: - 1.3.2 - 1.3.3 - 1.4 + - 1.4.1 + - 1.4.2 -env: - - GOARCH=amd64 +env: GOARCH=amd64 TEST_HOST=127.0.0.1 UPPERIO_DB_DEBUG=1 install: - - sudo apt-get install bzr - # - go get github.com/cznic/ql/ql # ql command line util. - # - go install github.com/cznic/ql/ql # ql command line util. - - go get -t -d - # - go get upper.io/db/mongo - # - go get upper.io/db/mysql - # - go get upper.io/db/postgresql - # - go get upper.io/db/sqlite - # - go get upper.io/db/ql + - sudo apt-get install -y bzr make + - mkdir -p $GOPATH/src/upper.io/v2 + - mv $PWD $GOPATH/src/upper.io/v2/db + - cd $GOPATH/src/upper.io/v2/db + - ls -la + - go get -v github.com/cznic/ql/ql # ql command line util. + - go get -v -t -d + - go get -v -t -d upper.io/v2/db/mysql + - go get -v -t -d upper.io/v2/db/sqlite + - go get -v -t -d upper.io/v2/db/postgresql + - go get -v -t -d upper.io/v2/db/mongo + - go get -v -t -d upper.io/v2/db/ql + - go get -v github.com/pkieltyka/sqlx + - (cd $GOPATH/src/github.com/pkieltyka/sqlx && git pull -a && git checkout ptrs) # temporal fix + - (cp -r $GOPATH/src/github.com/pkieltyka/sqlx/* $GOPATH/src/github.com/jmoiron/sqlx) # temporal fix + - (cd $GOPATH/src/github.com/jmoiron/sqlx && go build -a && go install) + - export TRAVIS_BUILD_DIR=$GOPATH/src/upper.io/v2/db services: - mongodb before_script: - - mkdir -p $HOME/gopath/src/upper.io - - mv $HOME/gopath/src/github.com/upper/db $HOME/gopath/src/upper.io/db - - cd $HOME/gopath/src/upper.io/db + - mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root mysql - cat mysql/_dumps/setup.sql | mysql -uroot - cat mysql/_dumps/structs.sql | mysql -uupperio -pupperio upperio_tests + - cat postgresql/_dumps/setup.sql | psql -U postgres - cat postgresql/_dumps/structs.sql | PGPASSWORD="upperio" psql -U upperio upperio_tests + - mongo upperio_tests --eval 'db.addUser("upperio", "upperio")' - # - cat ql/_dumps/structs.sql | $GOPATH/bin/ql -db ql/_dumps/test.db + + - (cd mysql/_dumps && make) + - (cd postgresql/_dumps && make) + - (cd sqlite/_dumps && make) + - (cd ql/_dumps && make) + + - cat ql/_dumps/structs.sql | $GOPATH/bin/ql -db ql/_dumps/test.db script: - - go version - - go test -host 127.0.0.1 + - cd $GOPATH/src/upper.io/v2/db + - go test upper.io/v2/db/mysql -test.bench=. + - go test upper.io/v2/db/sqlite -test.bench=. + - go test upper.io/v2/db/ql -test.bench=. + - go test upper.io/v2/db/mongo -test.bench=. + - go test -test.v diff --git a/README.md b/README.md index dd36fa94ab91df5c75c788cca85cb5ad84a5d26b..d8da77fe8e80d0542f90c7354410922af9053d5c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# upper.io/db +# upper.io/v2/db + +*V2 IS CURRENTLY IN DEVELOPMENT AND NOT YET READY TO USE* <center> <img src="https://upper.io/images/icon.svg" width="256" /> @@ -8,15 +10,15 @@ ## The `db` package - + -`upper.io/db` is a [Go][2] package that allows developers to communicate with +`upper.io/v2/db` is a [Go][2] package that allows developers to communicate with different databases through the use of *adapters* that wrap well-supported database drivers. -## Is `upper.io/db` an ORM? +## Is `upper.io/v2/db` an ORM? -`upper.io/db` is not an ORM in the sense that it does not tell you how to +`upper.io/v2/db` is not an ORM in the sense that it does not tell you how to design your software or how to validate your data, instead it only focuses on being a tool that deals with common operations on different databases: @@ -29,26 +31,26 @@ res = col.Find(db.Cond{"name": "Max"}).Limit(10).Sort("-last_name") err = res.All(&people) ``` -In strict sense `upper.io/db` could be considered a really basic non-magical +In strict sense `upper.io/v2/db` could be considered a really basic non-magical ORM that rather stays out of the way. ## Supported databases - + -`upper.io/db` attempts to provide full compatiblity for [CRUD][2] operations +`upper.io/v2/db` attempts to provide full compatiblity for [CRUD][2] operations across adapters. Some other operations (such *transactions*) are supported only on specific database adapters, such as MySQL, PostgreSQL and SQLite. -* [MongoDB](https://upper.io/db/mongo) via [mgo](http://godoc.org/labix.org/v2/mgo) -* [MySQL](https://upper.io/db/mysql) via [mysql](https://github.com/go-sql-driver/mysql) -* [PostgreSQL](https://upper.io/db/postgresql) via [pq](https://github.com/lib/pq) -* [QL](https://upper.io/db/ql) via [ql](https://github.com/cznic/ql) -* [SQLite3](https://upper.io/db/sqlite) via [go-sqlite3](https://github.com/mattn/go-sqlite3) +* [MongoDB](https://upper.io/v2/db/mongo) via [mgo](http://godoc.org/labix.org/v2/mgo) +* [MySQL](https://upper.io/v2/db/mysql) via [mysql](https://github.com/go-sql-driver/mysql) +* [PostgreSQL](https://upper.io/v2/db/postgresql) via [pq](https://github.com/lib/pq) +* [QL](https://upper.io/v2/db/ql) via [ql](https://github.com/cznic/ql) +* [SQLite3](https://upper.io/v2/db/sqlite) via [go-sqlite3](https://github.com/mattn/go-sqlite3) ## User documentation -See the project page, recipes and user documentation at [upper.io/db][1]. +See the project page, recipes and user documentation at [upper.io/v2/db][1]. ## License @@ -73,6 +75,6 @@ See the project page, recipes and user documentation at [upper.io/db][1]. > OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION > WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -[1]: https://upper.io/db +[1]: https://upper.io/v2/db [2]: http://golang.org [3]: http://en.wikipedia.org/wiki/Create,_read,_update_and_delete diff --git a/main.go b/db.go similarity index 96% rename from main.go rename to db.go index b04c69af52d5f9f786516b55599d4749dccda6cb..d243413715430cea7ea9da1edff267a6cba8c5be 100644 --- a/main.go +++ b/db.go @@ -22,26 +22,27 @@ // Package db provides a single interface for interacting with different data // sources through the use of adapters that wrap well-known database drivers. // -// As of today, `upper.io/db` fully supports MySQL, PostgreSQL and SQLite (CRUD +// As of today, `upper.io/v2/db` fully supports MySQL, PostgreSQL and SQLite (CRUD // + Transactions) and provides partial support for MongoDB and QL (CRUD only). // // Usage: // // import( // // Main package. -// "upper.io/db" +// "upper.io/v2/db" // // PostgreSQL adapter. -// "upper.io/db/postgresql" +// "upper.io/v2/db/postgresql" // ) // -// `upper.io/db` is not an ORM and thus does not impose any hard restrictions +// `upper.io/v2/db` is not an ORM and thus does not impose any hard restrictions // on data structures: // // // This code works the same for all supported databases. // var people []Person // res = col.Find(db.Cond{"name": "Max"}).Limit(2).Sort("-input") // err = res.All(&people) -package db // import "upper.io/db" + +package db // import "upper.io/v2/db" // Cond is a map used to define conditions passed to `db.Collection.Find()` and // `db.Result.Where()`. @@ -100,6 +101,10 @@ type Func struct { // } type And []interface{} +func (a And) And(exp ...interface{}) And { + return append(a, exp...) +} + // Or is an array of interfaced that is used to join two or more expressions // under logical disjunction, it accepts `db.Cond{}`, `db.And{}`, `db.Raw{}` // and other `db.Or{}` values. @@ -113,6 +118,10 @@ type And []interface{} // } type Or []interface{} +func (o Or) Or(exp ...interface{}) Or { + return append(o, exp...) +} + // Raw holds chunks of data to be passed to the database without any filtering. // Use with care. // diff --git a/main_test.go b/db_test.go similarity index 80% rename from main_test.go rename to db_test.go index a5b91f8cce7718a7a529ed147684d4545fa9cacc..09226836b54a56edf145f75e7e8ab1003ec4457a 100644 --- a/main_test.go +++ b/db_test.go @@ -24,90 +24,102 @@ package db_test import ( "database/sql" "errors" - "flag" + "fmt" "log" + "os" "reflect" - "strconv" "testing" "time" + "github.com/jmoiron/sqlx" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" - "upper.io/db" - _ "upper.io/db/mongo" - _ "upper.io/db/mysql" - _ "upper.io/db/postgresql" - // Temporary removing QL. It includes a _solaris.go file that produces - // compile time errors on < go1.3. - // _ "upper.io/db/ql" - _ "upper.io/db/sqlite" + "upper.io/v2/db" + "upper.io/v2/db/mongo" + + "upper.io/v2/db/mysql" + "upper.io/v2/db/postgresql" + "upper.io/v2/db/ql" + "upper.io/v2/db/sqlite" ) var wrappers = []string{ - `sqlite`, - `mysql`, - `postgresql`, - `mongo`, - // `ql`, + sqlite.Adapter, + mysql.Adapter, + postgresql.Adapter, + mongo.Adapter, + ql.Adapter, } const ( - TestAllWrappers = `all` + testAllWrappers = `all` ) var ( errDriverErr = errors.New(`Driver error`) ) -var settings map[string]*db.Settings +var settings map[string]db.ConnectionURL func init() { - // Getting host from the environment. - host := flag.String("host", "testserver.local", "Testing server address.") - wrapper := flag.String("wrapper", "all", "Wrappers to test.") + // Getting settings from the environment. + + var host string + if host = os.Getenv("TEST_HOST"); host == "" { + host = "localhost" + } - flag.Parse() + var wrapper string + if wrapper = os.Getenv("TEST_WRAPPER"); wrapper == "" { + wrapper = testAllWrappers + } - log.Printf("Running tests against host %s.\n", *host) + log.Printf("Running tests against host %s.\n", host) - settings = map[string]*db.Settings{ - `sqlite`: &db.Settings{ + settings = map[string]db.ConnectionURL{ + `sqlite`: &sqlite.ConnectionURL{ Database: `upperio_tests.db`, }, - `mongo`: &db.Settings{ + `mongo`: &mongo.ConnectionURL{ Database: `upperio_tests`, - Host: *host, + Address: db.Host(host), User: `upperio`, Password: `upperio`, }, - `mysql`: &db.Settings{ + `mysql`: &mysql.ConnectionURL{ Database: `upperio_tests`, - Host: *host, + Address: db.Host(host), User: `upperio`, Password: `upperio`, + Options: map[string]string{ + "parseTime": "true", + }, }, - `postgresql`: &db.Settings{ + `postgresql`: &postgresql.ConnectionURL{ Database: `upperio_tests`, - Host: *host, + Address: db.Host(host), User: `upperio`, Password: `upperio`, + Options: map[string]string{ + "timezone": "UTC", + }, }, - `ql`: &db.Settings{ - Database: `file://upperio_test.ql`, + `ql`: &ql.ConnectionURL{ + Database: `upperio_test.ql`, }, } - if *wrapper != TestAllWrappers { - wrappers = []string{*wrapper} - log.Printf("Testing wrapper %s.", *wrapper) + if wrapper != testAllWrappers { + wrappers = []string{wrapper} + log.Printf("Testing wrapper %s.", wrapper) } } var setupFn = map[string]func(driver interface{}) error{ `mongo`: func(driver interface{}) error { - if mgod, ok := driver.(*mgo.Session); ok == true { + if mgod, ok := driver.(*mgo.Session); ok { var col *mgo.Collection col = mgod.DB("upperio_tests").C("birthdays") col.DropCollection() @@ -125,17 +137,17 @@ var setupFn = map[string]func(driver interface{}) error{ return errDriverErr }, `postgresql`: func(driver interface{}) error { - if sqld, ok := driver.(*sql.DB); ok == true { + if sqld, ok := driver.(*sqlx.DB); ok { var err error - _, err = sqld.Exec(`DROP TABLE IF EXISTS birthdays`) + _, err = sqld.Exec(`DROP TABLE IF EXISTS "birthdays"`) if err != nil { return err } _, err = sqld.Exec(`CREATE TABLE "birthdays" ( "id" serial primary key, "name" CHARACTER VARYING(50), - "born" TIMESTAMP, + "born" TIMESTAMP WITH TIME ZONE, "born_ut" INT )`) if err != nil { @@ -161,7 +173,7 @@ var setupFn = map[string]func(driver interface{}) error{ } _, err = sqld.Exec(`CREATE TABLE "is_even" ( "input" NUMERIC, - "is_even" INT + "is_even" BOOL )`) if err != nil { return err @@ -172,8 +184,8 @@ var setupFn = map[string]func(driver interface{}) error{ return err } _, err = sqld.Exec(`CREATE TABLE "CaSe_TesT" ( - "ID" SERIAL PRIMARY KEY, - "Case_Test" VARCHAR(60) + "id" SERIAL PRIMARY KEY, + "case_test" VARCHAR(60) )`) if err != nil { return err @@ -181,17 +193,17 @@ var setupFn = map[string]func(driver interface{}) error{ return nil } - return errDriverErr + return fmt.Errorf("Expecting *sqlx.DB got %T (%#v).", driver, driver) }, `mysql`: func(driver interface{}) error { - if sqld, ok := driver.(*sql.DB); ok == true { + if sqld, ok := driver.(*sqlx.DB); ok { var err error - _, err = sqld.Exec(`DROP TABLE IF EXISTS birthdays`) + _, err = sqld.Exec(`DROP TABLE IF EXISTS ` + "`" + `birthdays` + "`" + ``) if err != nil { return err } - _, err = sqld.Exec(`CREATE TABLE birthdays ( + _, err = sqld.Exec(`CREATE TABLE ` + "`" + `birthdays` + "`" + ` ( id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), name VARCHAR(50), born DATE, @@ -201,11 +213,11 @@ var setupFn = map[string]func(driver interface{}) error{ return err } - _, err = sqld.Exec(`DROP TABLE IF EXISTS fibonacci`) + _, err = sqld.Exec(`DROP TABLE IF EXISTS ` + "`" + `fibonacci` + "`" + ``) if err != nil { return err } - _, err = sqld.Exec(`CREATE TABLE fibonacci ( + _, err = sqld.Exec(`CREATE TABLE ` + "`" + `fibonacci` + "`" + ` ( id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), input BIGINT(20) UNSIGNED NOT NULL, output BIGINT(20) UNSIGNED NOT NULL @@ -214,11 +226,11 @@ var setupFn = map[string]func(driver interface{}) error{ return err } - _, err = sqld.Exec(`DROP TABLE IF EXISTS is_even`) + _, err = sqld.Exec(`DROP TABLE IF EXISTS ` + "`" + `is_even` + "`" + ``) if err != nil { return err } - _, err = sqld.Exec(`CREATE TABLE is_even ( + _, err = sqld.Exec(`CREATE TABLE ` + "`" + `is_even` + "`" + ` ( input BIGINT(20) UNSIGNED NOT NULL, is_even TINYINT(1) ) CHARSET=utf8`) @@ -226,13 +238,13 @@ var setupFn = map[string]func(driver interface{}) error{ return err } - _, err = sqld.Exec(`DROP TABLE IF EXISTS CaSe_TesT`) + _, err = sqld.Exec(`DROP TABLE IF EXISTS ` + "`" + `CaSe_TesT` + "`" + ``) if err != nil { return err } - _, err = sqld.Exec(`CREATE TABLE CaSe_TesT ( - ID BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(ID), - Case_Test VARCHAR(60) + _, err = sqld.Exec(`CREATE TABLE ` + "`" + `CaSe_TesT` + "`" + ` ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + case_test VARCHAR(60) ) CHARSET=utf8`) if err != nil { return err @@ -240,10 +252,10 @@ var setupFn = map[string]func(driver interface{}) error{ return nil } - return errDriverErr + return fmt.Errorf("Expecting *sqlx.DB got %T (%#v).", driver, driver) }, `sqlite`: func(driver interface{}) error { - if sqld, ok := driver.(*sql.DB); ok == true { + if sqld, ok := driver.(*sqlx.DB); ok { var err error _, err = sqld.Exec(`DROP TABLE IF EXISTS "birthdays"`) @@ -253,7 +265,7 @@ var setupFn = map[string]func(driver interface{}) error{ _, err = sqld.Exec(`CREATE TABLE "birthdays" ( "id" INTEGER PRIMARY KEY, "name" VARCHAR(50) DEFAULT NULL, - "born" VARCHAR(12) DEFAULT NULL, + "born" DATETIME DEFAULT NULL, "born_ut" INTEGER )`) if err != nil { @@ -290,8 +302,8 @@ var setupFn = map[string]func(driver interface{}) error{ return err } _, err = sqld.Exec(`CREATE TABLE "CaSe_TesT" ( - "ID" INTEGER PRIMARY KEY, - "Case_Test" VARCHAR + "id" INTEGER PRIMARY KEY, + "case_test" VARCHAR )`) if err != nil { return err @@ -302,7 +314,7 @@ var setupFn = map[string]func(driver interface{}) error{ return errDriverErr }, `ql`: func(driver interface{}) error { - if sqld, ok := driver.(*sql.DB); ok == true { + if sqld, ok := driver.(*sqlx.DB); ok { var err error var tx *sql.Tx @@ -356,7 +368,7 @@ var setupFn = map[string]func(driver interface{}) error{ } _, err = tx.Exec(`CREATE TABLE CaSe_TesT ( - Case_Test string + case_test string )`) if err != nil { return err @@ -373,9 +385,9 @@ var setupFn = map[string]func(driver interface{}) error{ } type birthday struct { - Name string // `db:"name"` // Must match by name. - Born time.Time // `db:"born" // Must match by name. - BornUT *timeType `db:"born_ut"` + Name string `db:"name"` + Born time.Time `db:"born"` + BornUT timeType `db:"born_ut"` OmitMe bool `json:"omit_me" db:"-" bson:"-"` } @@ -383,30 +395,30 @@ type fibonacci struct { Input uint64 `db:"input"` Output uint64 `db:"output"` // Test for BSON option. - OmitMe bool `json:"omitme" db:",bson,omitempty" bson:"omit_me,omitempty"` + OmitMe bool `json:"omit_me" db:"omit_me,bson,omitempty" bson:"omit_me,omitempty"` } type oddEven struct { // Test for JSON option. - Input int `json:"input"` + Input int `json:"input" db:"input"` // Test for JSON option. // The "bson" tag is required by mgo. - IsEven bool `json:"is_even" db:",json" bson:"is_even"` + IsEven bool `json:"is_even" db:"is_even,json" bson:"is_even"` OmitMe bool `json:"omit_me" db:"-" bson:"-"` } // Struct that relies on explicit mapping. type mapE struct { - ID uint `db:"ID,omitempty" bson:"-"` + ID uint `db:"id,omitempty" bson:"-"` MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` - CaseTest string `db:"Case_Test" bson:"Case_Test"` + CaseTest string `db:"case_test" bson:"case_test"` } // Struct that will fallback to default mapping. type mapN struct { - ID uint `db:",omitempty"` - MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` - Casetest string + ID uint `db:"id,omitempty"` + MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` + Case_TEST string `db:"case_test"` } // Struct for testing marshalling. @@ -417,27 +429,30 @@ type timeType struct { } // time.Time -> unix timestamp -func (u *timeType) MarshalDB() (interface{}, error) { +func (u timeType) MarshalDB() (interface{}, error) { return u.value.Unix(), nil } // unix timestamp -> time.Time func (u *timeType) UnmarshalDB(v interface{}) error { - var i int + var unixTime int64 switch t := v.(type) { - case string: - i, _ = strconv.Atoi(t) + case int64: + unixTime = t default: return db.ErrUnsupportedValue } - t := time.Unix(int64(i), 0) + t := time.Unix(unixTime, 0).In(time.UTC) *u = timeType{t} return nil } +var _ db.Marshaler = timeType{} +var _ db.Unmarshaler = &timeType{} + func even(i int) bool { if i%2 == 0 { return true @@ -457,11 +472,13 @@ func fib(i uint64) uint64 { func TestOpen(t *testing.T) { var err error for _, wrapper := range wrappers { + t.Logf("Testing wrapper: %q", wrapper) + if settings[wrapper] == nil { t.Fatalf(`No such settings entry for wrapper %s.`, wrapper) } else { var sess db.Database - sess, err = db.Open(wrapper, *settings[wrapper]) + sess, err = db.Open(wrapper, settings[wrapper]) if err != nil { t.Fatalf(`Test for wrapper %s failed: %q`, wrapper, err) } @@ -476,12 +493,14 @@ func TestOpen(t *testing.T) { func TestSetup(t *testing.T) { var err error for _, wrapper := range wrappers { + t.Logf("Testing wrapper: %q", wrapper) + if settings[wrapper] == nil { t.Fatalf(`No such settings entry for wrapper %s.`, wrapper) } else { var sess db.Database - sess, err = db.Open(wrapper, *settings[wrapper]) + sess, err = db.Open(wrapper, settings[wrapper]) if err != nil { t.Fatalf(`Test for wrapper %s failed: %q`, wrapper, err) } @@ -513,21 +532,23 @@ func TestSimpleCRUD(t *testing.T) { t.Fatalf(`No such settings entry for wrapper %s.`, wrapper) } else { + t.Logf("Testing wrapper: %q", wrapper) + var sess db.Database - sess, err = db.Open(wrapper, *settings[wrapper]) + sess, err = db.Open(wrapper, settings[wrapper]) if err != nil { t.Fatalf(`Test for wrapper %s failed: %q`, wrapper, err) } defer sess.Close() - born := time.Date(1941, time.January, 5, 0, 0, 0, 0, time.Local) + born := time.Date(1941, time.January, 5, 0, 0, 0, 0, time.UTC) controlItem = birthday{ Name: "Hayao Miyazaki", Born: born, - BornUT: &timeType{born}, + BornUT: timeType{born}, } col, err := sess.Collection(`birthdays`) @@ -575,9 +596,13 @@ func TestSimpleCRUD(t *testing.T) { t.Fatalf("%s One(): %s", wrapper, err) } + if wrapper == `sqlite` { + // SQLite does not save time zone info, so you have to do this by hand. + testItem.Born = testItem.Born.In(time.UTC) + } + if reflect.DeepEqual(testItem, controlItem) == false { - t.Errorf("%s: testItem: %v (ts: %v)\n", wrapper, testItem, testItem.BornUT.value.Unix()) - t.Errorf("%s: controlItem: %v (ts: %v)\n", wrapper, controlItem, controlItem.BornUT.value.Unix()) + t.Errorf("%s: controlItem (inserted): %v (ts: %v)\n", wrapper, controlItem, controlItem.BornUT.value.Unix()) t.Fatalf("%s: Structs are different", wrapper) } @@ -592,6 +617,10 @@ func TestSimpleCRUD(t *testing.T) { } for _, testItem = range testItems { + if wrapper == `sqlite` { + // SQLite does not save time zone info, so you have to do this by hand. + testItem.Born = testItem.Born.In(time.UTC) + } if reflect.DeepEqual(testItem, controlItem) == false { t.Errorf("%s: testItem: %v\n", wrapper, testItem) t.Errorf("%s: controlItem: %v\n", wrapper, controlItem) @@ -606,7 +635,15 @@ func TestSimpleCRUD(t *testing.T) { t.Fatalf(`Could not update with wrapper %s: %q`, wrapper, err) } - res.One(&testItem) + err = res.One(&testItem) + if err != nil { + t.Fatalf("%s One(): %s", wrapper, err) + } + + if wrapper == `sqlite` { + // SQLite does not save time zone info, so you have to do this by hand. + testItem.Born = testItem.Born.In(time.UTC) + } if reflect.DeepEqual(testItem, controlItem) == false { t.Fatalf("Struct is different with wrapper %s.", wrapper) @@ -644,12 +681,15 @@ func TestFibonacci(t *testing.T) { var total uint64 for _, wrapper := range wrappers { + t.Logf("Testing wrapper: %q", wrapper) + if settings[wrapper] == nil { t.Fatalf(`No such settings entry for wrapper %s.`, wrapper) } else { + var sess db.Database - sess, err = db.Open(wrapper, *settings[wrapper]) + sess, err = db.Open(wrapper, settings[wrapper]) if err != nil { t.Fatalf(`Test for wrapper %s failed: %q`, wrapper, err) } @@ -856,12 +896,14 @@ func TestEven(t *testing.T) { var err error for _, wrapper := range wrappers { + t.Logf("Testing wrapper: %q", wrapper) + if settings[wrapper] == nil { t.Fatalf(`No such settings entry for wrapper %s.`, wrapper) } else { var sess db.Database - sess, err = db.Open(wrapper, *settings[wrapper]) + sess, err = db.Open(wrapper, settings[wrapper]) if err != nil { t.Fatalf(`Test for wrapper %s failed: %q`, wrapper, err) } @@ -965,26 +1007,6 @@ func TestEven(t *testing.T) { t.Fatalf("Expecting no data with wrapper %s. Got: %v\n", wrapper, item) } } - - // Testing (deprecated) "field" tag. - for { - // Testing named inputs (using tags). - var item struct { - Value uint `field:"input"` - } - err = res.Next(&item) - if err != nil { - if err == db.ErrNoMoreRows { - break - } else { - t.Fatalf(`%s: %v`, wrapper, err) - } - } - if item.Value%2 == 0 { - t.Fatalf("Expecting no data with wrapper %s. Got: %v\n", wrapper, item) - } - } - } } @@ -1000,22 +1022,21 @@ func TestExplicitAndDefaultMapping(t *testing.T) { var testN mapN for _, wrapper := range wrappers { + t.Logf("Testing wrapper: %q", wrapper) if settings[wrapper] == nil { t.Fatalf(`No such settings entry for wrapper %s.`, wrapper) } else { - if sess, err = db.Open(wrapper, *settings[wrapper]); err != nil { + if sess, err = db.Open(wrapper, settings[wrapper]); err != nil { t.Fatalf(`Test for wrapper %s failed: %q`, wrapper, err) } defer sess.Close() - col, err = sess.Collection("Case_Test") - if col, err = sess.Collection("CaSe_TesT"); err != nil { if wrapper == `mongo` && err == db.ErrCollectionDoesNotExist { - // Nothing, it's expected. + // Nothing, this is expected. } else { t.Fatal(err) } @@ -1023,7 +1044,7 @@ func TestExplicitAndDefaultMapping(t *testing.T) { if err = col.Truncate(); err != nil { if wrapper == `mongo` { - // Nothing, it's expected. + // Nothing, this is expected. } else { t.Fatal(err) } @@ -1038,10 +1059,10 @@ func TestExplicitAndDefaultMapping(t *testing.T) { t.Fatal(err) } - res = col.Find(db.Cond{"Case_Test": "Hello!"}) + res = col.Find(db.Cond{"case_test": "Hello!"}) if wrapper == `ql` { - res = res.Select(`id() as ID`, `Case_Test`) + res = res.Select(`id() as id`, `case_test`) } if err = res.One(&testE); err != nil { @@ -1060,7 +1081,7 @@ func TestExplicitAndDefaultMapping(t *testing.T) { // Testing default mapping. testN = mapN{ - Casetest: "World!", + Case_TEST: "World!", } if _, err = col.Append(testN); err != nil { @@ -1068,14 +1089,13 @@ func TestExplicitAndDefaultMapping(t *testing.T) { } if wrapper == `mongo` { - // We don't have this kind of control with mongodb. - res = col.Find(db.Cond{"casetest": "World!"}) + res = col.Find(db.Cond{"case_test": "World!"}) } else { - res = col.Find(db.Cond{"Case_Test": "World!"}) + res = col.Find(db.Cond{"case_test": "World!"}) } if wrapper == `ql` { - res = res.Select(`id() as ID`, `Case_Test`) + res = res.Select(`id() as id`, `case_test`) } if err = res.One(&testN); err != nil { diff --git a/error.go b/error.go index 631f435a5fea2aeab1dc489b92da7e8dafc61268..4110560b965ef194ec1cb14ce6995bac3242134b 100644 --- a/error.go +++ b/error.go @@ -47,7 +47,7 @@ var ( ErrUnsupportedDestination = errors.New(`Unsupported destination type.`) ErrUnsupportedType = errors.New(`This type does not support marshaling.`) ErrUnsupportedValue = errors.New(`This value does not support unmarshaling.`) - ErrUnknownConditionType = errors.New(`Arguments of type %s can't be used as constraints.`) + ErrUnknownConditionType = errors.New(`Arguments of type %T can't be used as constraints.`) ) // Deprecated but kept for backwards compatibility. See: https://github.com/upper/db/issues/18 diff --git a/mongo/README.md b/mongo/README.md index ded918a24a60942dae3eea3bc135c390acce108f..3627fb228ec856f79b4233f77cf6ec5e345d7f63 100644 --- a/mongo/README.md +++ b/mongo/README.md @@ -1,6 +1,6 @@ -# MongoDB adapter for upper.io/db +# MongoDB adapter for upper.io/v2/db Please read the full docs, acknowledgements and examples at -[https://upper.io/db/wrappers/mongo][1]. +[https://upper.io/v2/db/wrappers/mongo][1]. -[1]: https://upper.io/db/wrappers/mongo +[1]: https://upper.io/v2/db/wrappers/mongo diff --git a/mongo/_example/main.go b/mongo/_example/main.go index 88420f8557180d90ea5b6bc3b1eed0d94ec6a777..b6635258c7c86f0cfeed7c06a736d5d4e8be0d49 100644 --- a/mongo/_example/main.go +++ b/mongo/_example/main.go @@ -5,8 +5,8 @@ import ( "log" "time" - "upper.io/db" // Imports the main db package. - _ "upper.io/db/mongo" // Imports the mongo adapter. + "upper.io/v2/db" // Imports the main db package. + _ "upper.io/v2/db/mongo" // Imports the mongo adapter. ) var settings = db.Settings{ diff --git a/mongo/collection.go b/mongo/collection.go index df339a014f64098693ec87c89b2dcc52023cd5a5..ca0dfa36466b9cc2e3adcad3011594c7253e6736 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -30,7 +30,7 @@ import ( "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" - "upper.io/db" + "upper.io/v2/db" ) // Mongodb Collection diff --git a/mongo/connection.go b/mongo/connection.go index cb4917a0346170312d78c04f34063d20d321b4ae..f0a7f9cb90aaf691d1713742c6bb6bd6d8eb4fbd 100644 --- a/mongo/connection.go +++ b/mongo/connection.go @@ -26,7 +26,7 @@ import ( "net/url" "strings" - "upper.io/db" + "upper.io/v2/db" ) const connectionScheme = `mongodb` diff --git a/mongo/connection_test.go b/mongo/connection_test.go index 3e3f4328d0495e2e74d43c2f5a71ba374948288a..2ca33718fea6f5a2335b6667ca9f52722876302d 100644 --- a/mongo/connection_test.go +++ b/mongo/connection_test.go @@ -24,7 +24,7 @@ package mongo import ( "testing" - "upper.io/db" + "upper.io/v2/db" ) func TestConnectionURL(t *testing.T) { diff --git a/mongo/database.go b/mongo/database.go index b5148e2b97e5a6556c67d6e1b33bf9458665ccd4..bc1d4bf56e4ac4873c4c611c6aa7a45addb7e7cb 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -19,7 +19,7 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package mongo +package mongo // import "upper.io/v2/db/mongo" import ( "fmt" @@ -29,7 +29,7 @@ import ( "time" "gopkg.in/mgo.v2" - "upper.io/db" + "upper.io/v2/db" ) const Adapter = `mongo` diff --git a/mongo/database_test.go b/mongo/database_test.go index 1cd42677cd74b063d9b1a740e873031027a00bf1..8bfeef0961e12d1612e512df5fa616deff923efa 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -24,7 +24,6 @@ package mongo import ( "errors" - "flag" "math/rand" "os" "reflect" @@ -35,7 +34,7 @@ import ( "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" "menteslibres.net/gosexy/to" - "upper.io/db" + "upper.io/v2/db" ) // Wrapper settings. @@ -52,7 +51,7 @@ var settings = ConnectionURL{ Password: password, } -var host = flag.String("host", "testserver.local", "Testing server address.") +var host string // Structure for testing conversions and datatypes. type testValuesStruct struct { @@ -127,8 +126,11 @@ func init() { time.Second * time.Duration(7331), } - flag.Parse() - settings.Address = db.ParseAddress(*host) + if host = os.Getenv("TEST_HOST"); host == "" { + host = "localhost" + } + + settings.Address = db.ParseAddress(host) } // Enabling outputting some information to stdout, useful for development. @@ -155,7 +157,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: username, Password: password, } @@ -169,7 +171,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: username, Password: "fail", } @@ -181,7 +183,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong database. wrongSettings = db.Settings{ Database: "fail", - Host: *host, + Host: host, User: username, Password: password, } @@ -193,7 +195,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: "fail", Password: password, } diff --git a/mongo/result.go b/mongo/result.go index 8e6269b3575a7863f5a29c74a6a5ff0ccef24049..5609815e238d0f0910b26df6440c86175616ccaa 100644 --- a/mongo/result.go +++ b/mongo/result.go @@ -29,7 +29,7 @@ import ( "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" - "upper.io/db" + "upper.io/v2/db" ) type Result struct { diff --git a/mysql/README.md b/mysql/README.md index c2262cd3692ecb1a79203a5112c307213ca32eb8..6cdb62feefd25de2c66e190b1f7f4b95a10ea148 100644 --- a/mysql/README.md +++ b/mysql/README.md @@ -1,7 +1,7 @@ -# MySQL adapter for upper.io/db +# MySQL adapter for upper.io/v2/db See the full docs, acknowledgements and examples at -[https://upper.io/db/mysql][1] +[https://upper.io/v2/db/mysql][1] -[1]: https://upper.io/db/mysql +[1]: https://upper.io/v2/db/mysql diff --git a/mysql/_dumps/Makefile b/mysql/_dumps/Makefile index 5e9b4f17c51413eec7ced1a74cea64aa3e70e253..fb1b4e28247a353505402c796fec0018420ad638 100644 --- a/mysql/_dumps/Makefile +++ b/mysql/_dumps/Makefile @@ -1,2 +1,4 @@ +TEST_HOST ?= 127.0.0.1 + all: - cat structs.sql | mysql -uupperio -pupperio upperio_tests -htestserver.local + cat structs.sql | mysql -uupperio -pupperio upperio_tests -h$(TEST_HOST) diff --git a/mysql/_dumps/structs.sql b/mysql/_dumps/structs.sql index 09d4e89be6037a99822b77182d79a6a7e712826e..f0a298becec0779601a27f598b9f991406918544 100644 --- a/mysql/_dumps/structs.sql +++ b/mysql/_dumps/structs.sql @@ -47,10 +47,11 @@ CREATE TABLE data_types ( _float64 DECIMAL(10,6), _bool TINYINT(1), _string text, - _date DATETIME NOT NULL, + _date TIMESTAMP NULL, _nildate DATETIME NULL, _ptrdate DATETIME NULL, - _time TIME NOT NULL + _defaultdate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + _time BIGINT UNSIGNED NOT NULL ); DROP TABLE IF EXISTS stats_test; diff --git a/mysql/_example/main.go b/mysql/_example/main.go index fa4d20f651f9f24803f170596e1137a5f7615c7e..f2248a65c4e35397f365fa202efd09aa16b12d08 100644 --- a/mysql/_example/main.go +++ b/mysql/_example/main.go @@ -5,8 +5,8 @@ import ( "log" "time" - "upper.io/db" // Imports the main db package. - _ "upper.io/db/mysql" // Improts the mysql adapter. + "upper.io/v2/db" // Imports the main db package. + _ "upper.io/v2/db/mysql" // Improts the mysql adapter. ) var settings = db.Settings{ diff --git a/mysql/collection.go b/mysql/collection.go index 92ef8220447ef2f205c903af62aeedbfc3a306df..383bbb9fd62e24deefba91dee31218101b2cced4 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -23,237 +23,57 @@ package mysql import ( "database/sql" - "fmt" - "reflect" "strings" - "time" - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source - names []string -} - -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args + *database } -func interfaceArgs(value interface{}) (args []interface{}) { - - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) +var _ = db.Collection(&table{}) - for i = 0; i < total; i++ { - args[i] = toInternal(v.Index(i).Interface()) - } - - return args - } - return nil - default: - args = []interface{}{toInternal(value)} - } - - return args +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} - -func (c *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: c, - where: where, - arguments: arguments, - } - - return result -} - -func (c *table) tableN(i int) string { - if len(c.names) > i { - chunks := strings.SplitN(c.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. -func (c *table) Truncate() error { - - _, err := c.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{c.tableN(0)}, +// Truncate deletes all rows from the table. +func (t *table) Truncate() error { + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. -func (c *table) Append(item interface{}) (interface{}, error) { - +// Append inserts an item (map or struct) into the collection. +func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - cols, vals, err := c.FieldValues(item, toInternal) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = c.source.getPrimaryKey(c.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -261,14 +81,14 @@ func (c *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{c.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = c.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -295,10 +115,10 @@ func (c *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -321,72 +141,15 @@ func (c *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. -func (c *table) Exists() bool { - if err := c.source.tableExists(c.names...); err != nil { +// Exists returns true if the collection exists. +func (t *table) Exists() bool { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } -func (c *table) Name() string { - return strings.Join(c.names, `, `) -} - -// Converts a Go value into internal database representation. -func toInternal(val interface{}) interface{} { - switch t := val.(type) { - case db.Marshaler: - return t - case []byte: - return string(t) - case *time.Time: - if t == nil || t.IsZero() { - return sqlgen.Value{sqlgen.Raw{mysqlNull}} - } - return t.Format(DateFormat) - case time.Time: - if t.IsZero() { - return sqlgen.Value{sqlgen.Raw{mysqlNull}} - } - return t.Format(DateFormat) - case time.Duration: - return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond) - case sql.NullBool: - if t.Valid { - if t.Bool { - return toInternal(t.Bool) - } - return false - } - return sqlgen.Value{sqlgen.Raw{mysqlNull}} - case sql.NullFloat64: - if t.Valid { - if t.Float64 != 0.0 { - return toInternal(t.Float64) - } - return float64(0) - } - return sqlgen.Value{sqlgen.Raw{mysqlNull}} - case sql.NullInt64: - if t.Valid { - if t.Int64 != 0 { - return toInternal(t.Int64) - } - return 0 - } - return sqlgen.Value{sqlgen.Raw{mysqlNull}} - case sql.NullString: - if t.Valid { - return toInternal(t.String) - } - return sqlgen.Value{sqlgen.Raw{mysqlNull}} - case bool: - if t == true { - return `1` - } - return `0` - } - - return to.String(val) +// Name returns the name of the table or tables that form the collection. +func (t *table) Name() string { + return strings.Join(t.Tables, `, `) } diff --git a/mysql/connection.go b/mysql/connection.go index faa3b4cfa8d925a5fcbb70aa23011c5ecdb67b29..68d583c5b8d23ec8945d0fb430a4361c961acaa5 100644 --- a/mysql/connection.go +++ b/mysql/connection.go @@ -26,7 +26,7 @@ import ( "net/url" "strings" - "upper.io/db" + "upper.io/v2/db" ) const defaultPort = 3306 diff --git a/mysql/connection_test.go b/mysql/connection_test.go index 1e833d1e674eb5f039f405df06e52e49416db8ef..f700c9e6f1058f9371d6cbf197dd0209229931b1 100644 --- a/mysql/connection_test.go +++ b/mysql/connection_test.go @@ -24,7 +24,7 @@ package mysql import ( "testing" - "upper.io/db" + "upper.io/v2/db" ) func TestConnectionURL(t *testing.T) { diff --git a/mysql/database.go b/mysql/database.go index 879ae94cc5f5560c105fad00b1a74057fefdf61f..8bcb7ebb646f0614358950b382095c7ccd4690da 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -23,137 +23,284 @@ package mysql import ( "database/sql" - "os" "strings" "time" - // Importing MySQL driver. - _ "github.com/go-sql-driver/mysql" - "upper.io/cache" - "upper.io/db" - "upper.io/db/util/schema" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -const ( - // Adapter is the public name of the adapter. - Adapter = `mysql` -) -var ( - // DateFormat defines the format used for storing dates. - DateFormat = "2006-01-02 15:04:05.000" - // TimeFormat defines the format used for storing time values. - TimeFormat = "%d:%02d:%02d.%03d" + _ "github.com/go-sql-driver/mysql" // MySQL driver. + "github.com/jmoiron/sqlx" + "upper.io/v2/db" + "upper.io/v2/db/util/schema" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/tx" ) -var template *sqlgen.Template - var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL - session *sql.DB - tx *tx + session *sqlx.DB + tx *sqltx.Tx schema *schema.DatabaseSchema } +type tx struct { + *sqltx.Tx + *database +} + +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) + type columnSchemaT struct { Name string `db:"column_name"` } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session +} + +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error + + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + User: settings.User, + Password: settings.Password, + Database: settings.Database, + Options: map[string]string{ + "charset": settings.Charset, + }, + } + + // Connection charset, UTF-8 by default. + if conn.Options["charset"] == "" { + conn.Options["charset"] = "utf8" + } + + if settings.Socket != "" { + conn.Address = db.Socket(settings.Socket) + } else { + if settings.Host == "" { + settings.Host = "127.0.0.1" + } + if settings.Port == 0 { + settings.Port = defaultPort + } + conn.Address = db.HostPort(settings.Host, uint(settings.Port)) + } + + // Replace original d.connURL + d.connURL = conn + } + + if d.session, err = sqlx.Open(`mysql`, d.connURL.String()); err != nil { + return err + } + + d.session.Mapper = sqlutil.NewMapper() + + if err = d.populateSchema(); err != nil { + return err } - return false + + return nil +} + +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +func (d *database) clone() (*database, error) { + src := &database{} + src.Setup(d.connURL) + + if err := src.Open(); err != nil { + return nil, err } + + return src, nil } -func init() { - - template = &sqlgen.Template{ - mysqlColumnSeparator, - mysqlIdentifierSeparator, - mysqlIdentifierQuote, - mysqlValueSeparator, - mysqlValueQuote, - mysqlAndKeyword, - mysqlOrKeyword, - mysqlNotKeyword, - mysqlDescKeyword, - mysqlAscKeyword, - mysqlDefaultOperator, - mysqlClauseGroup, - mysqlClauseOperator, - mysqlColumnValue, - mysqlTableAliasLayout, - mysqlColumnAliasLayout, - mysqlSortByColumnLayout, - mysqlWhereLayout, - mysqlOrderByLayout, - mysqlInsertLayout, - mysqlSelectLayout, - mysqlUpdateLayout, - mysqlDeleteLayout, - mysqlTruncateLayout, - mysqlDropDatabaseLayout, - mysqlDropTableLayout, - mysqlSelectCountLayout, - mysqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -func (s *source) populateSchema() (err error) { - var collections []string +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil +} - s.schema = schema.NewDatabaseSchema() +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error - // Get database name. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {sqlgen.Raw{`DATABASE()`}}, - }, + if len(names) == 0 { + return nil, db.ErrMissingCollectionName } - var row *sql.Row + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } - if row, err = s.doQueryRow(stmt); err != nil { - return err + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper + + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) + + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } + + tableName := chunks[0] + + if err := d.tableExists(tableName); err != nil { + return nil, err + } + + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } } - if err = row.Scan(&s.schema.Name); err != nil { - return err + return col, nil +} + +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { + + tablesInSchema := len(d.schema.Tables) + + // Is schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil } - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { - return err + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Table: sqlgen.TableWithName(`information_schema.tables`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err + // Executing statement. + var rows *sqlx.Rows + if rows, err = d.Query(stmt, d.schema.Name); err != nil { + return nil, err + } + + defer rows.Close() + + collections = []string{} + + var name string + + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err } + + // Adding table entry to schema. + d.schema.AddTable(name) + + // Adding table to collections array. + collections = append(collections, name) + } + + return collections, nil +} + +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL + + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err } + conn.Database = database + + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) + return err } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name +} + +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err + } + + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err + } + + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil +} + +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -163,26 +310,27 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - res, err = s.session.Exec(query, args...) + res, err = d.session.Exec(query, args...) } return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { - var rows *sql.Rows +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows var query string var err error var start, end int64 @@ -191,27 +339,28 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - rows, err = s.tx.sqlTx.Query(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - rows, err = s.session.Query(query, args...) + rows, err = d.session.Queryx(query, args...) } return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string - var row *sql.Row + var row *sqlx.Row var err error var start, end int64 @@ -219,248 +368,97 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - row = s.tx.sqlTx.QueryRow(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - row = s.session.QueryRow(query, args...) + row = d.session.QueryRowx(query, args...) } return row, err } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} - -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) clone() (*source, error) { - src := &source{} - src.Setup(s.connURL) - - if err := src.Open(); err != nil { - return nil, err - } - - return src, nil -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sql.Tx + d.schema = schema.NewDatabaseSchema() - if sqlTx, err = s.session.Begin(); err != nil { - return nil, err - } - - if clone, err = s.clone(); err != nil { - return nil, err + // Get database name. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.RawValue(`DATABASE()`), + ), } - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx + var row *sqlx.Row - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { - s.connURL = connURL - return s.Open() -} - -// Returns the underlying *sql.DB instance. -func (s *source) Driver() interface{} { - return s.session -} - -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { - var err error - - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - User: settings.User, - Password: settings.Password, - Database: settings.Database, - Options: map[string]string{ - "charset": settings.Charset, - }, - } - - // Connection charset, UTF-8 by default. - if conn.Options["charset"] == "" { - conn.Options["charset"] = "utf8" - } - - if settings.Socket != "" { - conn.Address = db.Socket(settings.Socket) - } else { - if settings.Host == "" { - settings.Host = "127.0.0.1" - } - if settings.Port == 0 { - settings.Port = defaultPort - } - conn.Address = db.HostPort(settings.Host, uint(settings.Port)) - } - - // Replace original s.connURL - s.connURL = conn - } - - if s.session, err = sql.Open(`mysql`, s.connURL.String()); err != nil { + if row, err = d.QueryRow(stmt); err != nil { return err } - if err = s.populateSchema(); err != nil { + if err = row.Scan(&d.schema.Name); err != nil { return err } - return nil -} - -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} - -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL - - if conn, err = ParseURL(s.connURL.String()); err != nil { + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { return err } - conn.Database = database - - s.connURL = conn - - return s.Open() -} - -// Drops the currently active database. -func (s *source) Drop() error { - - _, err := s.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlDropDatabase, - Database: sqlgen.Database{s.schema.Name}, - }) - - return err -} - -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { - - tablesInSchema := len(s.schema.Tables) - - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil - } - - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Table: sqlgen.Table{ - `information_schema.tables`, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_schema`}, - `=`, - sqlPlaceholder, - }, - }, - } - - // Executing statement. - var rows *sql.Rows - if rows, err = s.doQuery(stmt, s.schema.Name); err != nil { - return nil, err - } - - defer rows.Close() - - collections = []string{} - - var name string - - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err } - - // Adding table entry to schema. - s.schema.AddTable(name) - - // Adding table to collections array. - collections = append(collections, name) } - return collections, nil + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error - var rows *sql.Rows + var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.tables`}, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.tables`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -474,32 +472,40 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.columns`}, - Columns: sqlgen.Columns{ - {`column_name`}, - {`data_type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.columns`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`column_name`), + sqlgen.ColumnWithName(`data_type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - var rows *sql.Rows + var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -509,96 +515,64 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) - } - - return s.schema.TableInfo[tableName].Columns, nil -} - -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } - - col := &table{ - source: s, - names: names, - } - - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) - - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName - } - - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } - - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err - } + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return col, nil + return d.schema.TableInfo[tableName].Columns, nil } -// getPrimaryKey returns the names of the columns that define the primary key -// of the table. -func (s *source) getPrimaryKey(tableName string) ([]string, error) { +func (d *database) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { return tableSchema.PrimaryKey, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{ - sqlgen.Raw{` + Type: sqlgen.Select, + Table: sqlgen.RawValue(` information_schema.table_constraints AS t JOIN information_schema.key_column_usage k USING(constraint_name, table_schema, table_name) - `}, - }, - Columns: sqlgen.Columns{ - {`k.column_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`t.constraint_type`}, `=`, sqlgen.Value{`primary key`}}, - sqlgen.ColumnValue{sqlgen.Column{`t.table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`t.table_name`}, `=`, sqlPlaceholder}, - }, - OrderBy: sqlgen.OrderBy{ - sqlgen.SortColumns{ - { - sqlgen.Column{`k.ordinal_position`}, - sqlgen.SqlSortAsc, - }, + `), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`k.column_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.constraint_type`), + Operator: `=`, + Value: sqlgen.NewValue(`primary key`), + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.table_schema`), + Operator: `=`, + Value: sqlPlaceholder, }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), + OrderBy: &sqlgen.OrderBy{ + SortColumns: sqlgen.JoinSortColumns( + &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(`k.ordinal_position`), + Order: sqlgen.Ascendent, + }, + ), }, } - var rows *sql.Rows + var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } diff --git a/mysql/database_test.go b/mysql/database_test.go index 60e5febb7793458a43b1ccc231910e07c7c6a678..f3c9cc78eaedf6e50b6f6f853e0354bf5d8fcc5e 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -21,17 +21,10 @@ package mysql -// In order to execute these tests you must initialize the database first: -// -// cd _dumps -// make -// cd .. -// go test - import ( "database/sql" "errors" - "flag" + "fmt" "math/rand" "os" "reflect" @@ -40,49 +33,61 @@ import ( "testing" "time" + "github.com/jmoiron/sqlx" "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util/sqlutil" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlutil" ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" +) + +const ( + testTimeZone = "Canada/Eastern" ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, + Options: map[string]string{ + // See https://github.com/go-sql-driver/mysql/issues/9 + "parseTime": "true", + // Might require you to use mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root -p mysql + "time_zone": fmt.Sprintf(`"%s"`, testTimeZone), + }, } -var host = flag.String("host", "testserver.local", "Testing server address.") +var host string // Structure for testing conversions and datatypes. type testValuesStruct struct { - Uint uint `field:"_uint"` - Uint8 uint8 `field:"_uint8"` - Uint16 uint16 `field:"_uint16"` - Uint32 uint32 `field:"_uint32"` - Uint64 uint64 `field:"_uint64"` - - Int int `field:"_int"` - Int8 int8 `field:"_int8"` - Int16 int16 `field:"_int16"` - Int32 int32 `field:"_int32"` - Int64 int64 `field:"_int64"` - - Float32 float32 `field:"_float32"` - Float64 float64 `field:"_float64"` - - Bool bool `field:"_bool"` - String string `field:"_string"` - - Date time.Time `field:"_date"` - DateN *time.Time `field:"_nildate"` - DateP *time.Time `field:"_ptrdate"` - Time time.Duration `field:"_time"` + Uint uint `db:"_uint"` + Uint8 uint8 `db:"_uint8"` + Uint16 uint16 `db:"_uint16"` + Uint32 uint32 `db:"_uint32"` + Uint64 uint64 `db:"_uint64"` + + Int int `db:"_int"` + Int8 int8 `db:"_int8"` + Int16 int16 `db:"_int16"` + Int32 int32 `db:"_int32"` + Int64 int64 `db:"_int64"` + + Float32 float32 `db:"_float32"` + Float64 float64 `db:"_float64"` + + Bool bool `db:"_bool"` + String string `db:"_string"` + + Date time.Time `db:"_date"` + DateN *time.Time `db:"_nildate"` + DateP *time.Time `db:"_ptrdate"` + DateD *time.Time `db:"_defaultdate,omitempty"` + Time int64 `db:"_time"` } type artistWithInt64Key struct { @@ -121,7 +126,14 @@ func (item *itemWithKey) SetID(keys map[string]interface{}) error { var testValues testValuesStruct func init() { - t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) + loc, err := time.LoadLocation(testTimeZone) + + if err != nil { + panic(err.Error()) + } + + t := time.Date(2011, 7, 28, 1, 2, 3, 0, loc) + tnz := time.Date(2012, 7, 28, 1, 2, 3, 0, time.UTC) testValues = testValuesStruct{ 1, 1, 1, 1, 1, @@ -131,12 +143,16 @@ func init() { "Hello world!", t, nil, - &t, - time.Second * time.Duration(7331), + &tnz, + nil, + int64(time.Second * 1337), } - flag.Parse() - settings.Address = db.ParseAddress(*host) + if host = os.Getenv("TEST_HOST"); host == "" { + host = "localhost" + } + + settings.Address = db.ParseAddress(host) } // Loggin some information to stdout (like the SQL query and its @@ -163,8 +179,8 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, - Host: *host, + Database: databaseName, + Host: host, User: username, Password: password, } @@ -177,8 +193,8 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, - Host: *host, + Database: databaseName, + Host: host, User: username, Password: "fail", } @@ -190,7 +206,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong database. wrongSettings = db.Settings{ Database: "fail", - Host: *host, + Host: host, User: username, Password: password, } @@ -201,8 +217,8 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, - Host: *host, + Database: databaseName, + Host: host, User: "fail", Password: password, } @@ -218,10 +234,10 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, - Host: *host, + Host: host, } // Opening database. @@ -599,39 +615,10 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an struct with no tags. - rowStruct := struct { - ID uint64 - Name string - }{} - - res = artist.Find() - - for { - err = res.Next(&rowStruct) - - if err == db.ErrNoMoreRows { - break - } - - if err == nil { - if rowStruct.ID == 0 { - t.Fatalf("Expecting a not null ID.") - } - if rowStruct.Name == "" { - t.Fatalf("Expecting a name.") - } - } else { - t.Fatal(err) - } - } - - res.Close() - // Dumping into a tagged struct. rowStruct2 := struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find() @@ -657,7 +644,7 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an slice of maps. + // Dumping into a slice of maps. allRowsMap := []map[string]interface{}{} res = artist.Find() @@ -676,10 +663,9 @@ func TestResultFetch(t *testing.T) { } // Dumping into an slice of structs. - allRowsStruct := []struct { - ID uint64 - Name string + ID uint64 `db:"id"` + Name string `db:"name"` }{} res = artist.Find() @@ -699,8 +685,8 @@ func TestResultFetch(t *testing.T) { // Dumping into an slice of tagged structs. allRowsStruct2 := []struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find() @@ -738,8 +724,8 @@ func TestUpdate(t *testing.T) { // Defining destination struct value := struct { - ID uint64 - Name string + ID uint64 `db:"id"` + Name string `db:"name"` }{} // Getting the first artist. @@ -770,7 +756,7 @@ func TestUpdate(t *testing.T) { // Updating set with a struct rowStruct := struct { - Name string + Name string `db:"name"` }{strings.ToLower(value.Name)} if err = res.Update(rowStruct); err != nil { @@ -1103,9 +1089,9 @@ func TestRawRelations(t *testing.T) { func TestRawQuery(t *testing.T) { var sess db.Database - var rows *sql.Rows + var rows *sqlx.Rows var err error - var drv *sql.DB + var drv *sqlx.DB type publicationType struct { ID int64 `db:"id,omitempty"` @@ -1119,9 +1105,9 @@ func TestRawQuery(t *testing.T) { defer sess.Close() - drv = sess.Driver().(*sql.DB) + drv = sess.Driver().(*sqlx.DB) - rows, err = drv.Query(` + rows, err = drv.Queryx(` SELECT p.id, p.title AS publication_title, @@ -1416,10 +1402,24 @@ func TestDataTypes(t *testing.T) { // Trying to dump the subject into an empty structure of the same type. var item testValuesStruct - res.One(&item) + if err = res.One(&item); err != nil { + t.Fatal(err) + } + + if item.DateD == nil { + t.Fatal("Expecting default date to have been set on append") + } + + // Copy the default date (this value is set by the database) + testValues.DateD = item.DateD + + loc, _ := time.LoadLocation(testTimeZone) + item.Date = item.Date.In(loc) // The original value and the test subject must match. if reflect.DeepEqual(item, testValues) == false { + fmt.Printf("item1: %v\n", item) + fmt.Printf("test2: %v\n", testValues) t.Fatalf("Struct is different.") } } @@ -1440,7 +1440,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec("TRUNCATE TABLE `artist`"); err != nil { b.Fatal(err) @@ -1494,7 +1494,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/mysql/mysql.go b/mysql/mysql.go new file mode 100644 index 0000000000000000000000000000000000000000..9d48672a576bd9a4138bf5661f37bcfa300677ec --- /dev/null +++ b/mysql/mysql.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package mysql // import "upper.io/v2/db/mysql" + +import ( + "upper.io/cache" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `mysql` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/mysql/result.go b/mysql/result.go deleted file mode 100644 index 6861eef8deb185128dec4db5af630a92bf61245d..0000000000000000000000000000000000000000 --- a/mysql/result.go +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package mysql - -import ( - "database/sql" - "fmt" - "strings" - - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"_t"` -} - -type result struct { - table *table - cursor *sql.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) error { - - var err error - - // Current cursor. - err = r.setCursor() - - if err != nil { - r.Close() - } - - // Fetching the next result from the cursor. - err = sqlutil.FetchRow(r.cursor, dst) - - if err != nil { - r.Close() - } - - return err -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values, toInternal) - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() error { - var err error - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts matching elements. -func (r *result) Count() (uint64, error) { - var count counter - - rows, err := r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - defer rows.Close() - if err = sqlutil.FetchRow(rows, &count); err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/mysql/layout.go b/mysql/template.go similarity index 63% rename from mysql/layout.go rename to mysql/template.go index c0ef34060f173ded0bbe07bc95f3c03611f6cf8f..1b2d21dab87f67e0f0cc87ed872d9ba9f37b1112 100644 --- a/mysql/layout.go +++ b/mysql/template.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,37 +22,38 @@ package mysql const ( - mysqlColumnSeparator = `.` - mysqlIdentifierSeparator = `, ` - mysqlIdentifierQuote = "`{{.Raw}}`" - mysqlValueSeparator = `, ` - mysqlValueQuote = `'{{.}}'` - mysqlAndKeyword = `AND` - mysqlOrKeyword = `OR` - mysqlNotKeyword = `NOT` - mysqlDescKeyword = `DESC` - mysqlAscKeyword = `ASC` - mysqlDefaultOperator = `=` - mysqlClauseGroup = `({{.}})` - mysqlClauseOperator = ` {{.}} ` - mysqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - mysqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - mysqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - mysqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - mysqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = "`{{.Value}}`" + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - mysqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - mysqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -79,19 +80,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - mysqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - mysqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - mysqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -106,7 +107,7 @@ const ( {{end}} ` - mysqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -114,23 +115,21 @@ const ( {{.Extra}} ` - mysqlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} ` - mysqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - mysqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - mysqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - mysqlNull = `NULL` ) diff --git a/postgresql/README.md b/postgresql/README.md index f374a92b187845718474af6de4c985a6f953b10a..2c1e720130338f7d83ad14bdb169f1f338676d58 100644 --- a/postgresql/README.md +++ b/postgresql/README.md @@ -1,6 +1,6 @@ -# PostgreSQL adapter for upper.io/db +# PostgreSQL adapter for upper.io/v2/db Please read the full docs, acknowledgements and examples at -[https://upper.io/db/postgresql][1] +[https://upper.io/v2/db/postgresql][1] -[1]: https://upper.io/db/postgresql +[1]: https://upper.io/v2/db/postgresql diff --git a/postgresql/_dumps/Makefile b/postgresql/_dumps/Makefile index 675048196a990e22c46df5ac45a4eac2ea0c9818..83de0013912ec7f3ef76d69fd8bb4fd45f5b5a54 100644 --- a/postgresql/_dumps/Makefile +++ b/postgresql/_dumps/Makefile @@ -1,2 +1,4 @@ +TEST_HOST ?= 127.0.0.1 + all: - cat structs.sql | PGPASSWORD="upperio" psql -Uupperio upperio_tests -htestserver.local + cat structs.sql | PGPASSWORD="upperio" psql -Uupperio upperio_tests -h$(TEST_HOST) diff --git a/postgresql/_dumps/structs.sql b/postgresql/_dumps/structs.sql index 2298f5cf26ddd8428493b59f77a79672f1e280c5..25a210f573e6ba623a789e25db6cda8da9b7fd50 100644 --- a/postgresql/_dumps/structs.sql +++ b/postgresql/_dumps/structs.sql @@ -41,10 +41,11 @@ CREATE TABLE data_types ( _float64 numeric(10,6), _bool boolean, _string text, - _date timestamp without time zone, + _date timestamp with time zone, _nildate timestamp without time zone null, _ptrdate timestamp without time zone, - _time time without time zone + _defaultdate timestamp without time zone DEFAULT now(), + _time bigint ); DROP TABLE IF EXISTS stats_test; @@ -63,3 +64,12 @@ CREATE TABLE composite_keys ( some_val varchar(255) default '', primary key (code, user_id) ); + +DROP TABLE IF EXISTS option_types; + +CREATE TABLE option_types ( + id serial primary key, + name varchar(255) default '', + tags varchar(64)[], + settings jsonb +); diff --git a/postgresql/_example/main.go b/postgresql/_example/main.go index 6377714527ccd11eaf5c562b95b1f721c30161cd..a1e916065e120ae0875f389ae40da445ab7f7c23 100644 --- a/postgresql/_example/main.go +++ b/postgresql/_example/main.go @@ -5,8 +5,8 @@ import ( "log" "time" - "upper.io/db" // Imports the main db package. - _ "upper.io/db/postgresql" // Imports the postgresql adapter. + "upper.io/v2/db" // Imports the main db package. + _ "upper.io/v2/db/postgresql" // Imports the postgresql adapter. ) var settings = db.Settings{ diff --git a/postgresql/collection.go b/postgresql/collection.go index f6bd801ca67b876d594e2cf3dfca330c99193f96..5f10d9f7f16748ce5bb794436ceca621d4b10f52 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -24,194 +24,34 @@ package postgresql import ( "database/sql" "fmt" - "reflect" "strings" - "time" - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" + "github.com/jmoiron/sqlx" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source + *database primaryKey string - names []string } -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = toInternal(v.Index(i).Interface()) - } - - return args - } - return nil - default: - args = []interface{}{toInternal(value)} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} +var _ = db.Collection(&table{}) +// Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: t, - where: where, - arguments: arguments, - } - - return result + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - - _, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{t.tableN(0)}, + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { @@ -221,41 +61,24 @@ func (t *table) Truncate() error { return nil } -// Appends an item (map or struct) into the collection. +// Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - //var id []interface{} - - cols, vals, err := t.FieldValues(item, toInternal) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = t.source.getPrimaryKey(t.tableN(0)); err != nil { + var pKey []string + + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -263,17 +86,17 @@ func (t *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{t.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } // No primary keys defined. if len(pKey) == 0 { var res sql.Result - if res, err = t.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -284,18 +107,20 @@ func (t *table) Append(item interface{}) (interface{}, error) { return lastID, nil } - var rows *sql.Rows + var rows *sqlx.Rows // A primary key was found. stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) - if rows, err = t.source.doQuery(stmt, arguments...); err != nil { + if rows, err = t.database.Query(stmt, sqlgenArgs...); err != nil { return nil, err } defer rows.Close() - var keyMap map[string]interface{} - err = sqlutil.FetchRow(rows, &keyMap) + keyMap := map[string]interface{}{} + if err := sqlutil.FetchRow(rows, &keyMap); err != nil { + return nil, err + } // Does the item satisfy the db.IDSetter interface? if setter, ok := item.(db.IDSetter); ok { @@ -305,15 +130,13 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, nil } - // The IDSetter interface does not match, we'll be looking for another - // interface match. + // The IDSetter interface does not match, look for another interface match. if len(keyMap) == 1 { - id := keyMap[pKey[0]] // Matches db.Int64IDSetter if setter, ok := item.(db.Int64IDSetter); ok { - if err = setter.SetID(to.Int64(id)); err != nil { + if err = setter.SetID(id.(int64)); err != nil { return nil, err } return nil, nil @@ -321,86 +144,29 @@ func (t *table) Append(item interface{}) (interface{}, error) { // Matches db.Uint64IDSetter if setter, ok := item.(db.Uint64IDSetter); ok { - if err = setter.SetID(to.Uint64(id)); err != nil { + if err = setter.SetID(uint64(id.(int64))); err != nil { return nil, err } return nil, nil } // No interface matched, falling back to old behaviour. - return to.Int64(id), nil + return id.(int64), nil } // More than one key, no interface matched, let's return a map. return keyMap, nil } -// Returns true if the collection exists. +// Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) -} - -// Converts a Go value into internal database representation. -func toInternal(val interface{}) interface{} { - switch v := val.(type) { - case db.Marshaler: - return v - case []byte: - return string(v) - case *time.Time: - if v == nil || v.IsZero() { - return sqlgen.Value{sqlgen.Raw{psqlNull}} - } - return v.Format(DateFormat) - case time.Time: - if v.IsZero() { - return sqlgen.Value{sqlgen.Raw{psqlNull}} - } - return v.Format(DateFormat) - case time.Duration: - return fmt.Sprintf(TimeFormat, int(v/time.Hour), int(v/time.Minute%60), int(v/time.Second%60), v%time.Second/time.Millisecond) - case sql.NullBool: - if v.Valid { - if v.Bool { - return toInternal(v.Bool) - } - return false - } - return sqlgen.Value{sqlgen.Raw{psqlNull}} - case sql.NullFloat64: - if v.Valid { - if v.Float64 != 0.0 { - return toInternal(v.Float64) - } - return float64(0) - } - return sqlgen.Value{sqlgen.Raw{psqlNull}} - case sql.NullInt64: - if v.Valid { - if v.Int64 != 0 { - return toInternal(v.Int64) - } - return 0 - } - return sqlgen.Value{sqlgen.Raw{psqlNull}} - case sql.NullString: - if v.Valid { - return toInternal(v.String) - } - return sqlgen.Value{sqlgen.Raw{psqlNull}} - case bool: - if v { - return `1` - } - return `0` - } - - return to.String(val) + return strings.Join(t.Tables, `, `) } diff --git a/postgresql/connection.go b/postgresql/connection.go index 91b6bd3641ce796e3b85d24e0ce800f3b246dd5f..4758a902ae6acfe43da422e7f5344252600bcc33 100644 --- a/postgresql/connection.go +++ b/postgresql/connection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -27,8 +27,8 @@ import ( "strings" "unicode" - "github.com/xiam/gopostgresql" - "upper.io/db" + "github.com/lib/pq" + "upper.io/v2/db" ) // scanner implements a tokenizer for libpq-style option strings. @@ -37,8 +37,8 @@ type scanner struct { i int } -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. +// Next returns the next rune. It returns 0, false if the end of the text has +// been reached. func (s *scanner) Next() (rune, bool) { if s.i >= len(s.s) { return 0, false @@ -48,8 +48,8 @@ func (s *scanner) Next() (rune, bool) { return r, true } -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. +// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the +// end of the text has been reached. func (s *scanner) SkipSpaces() (rune, bool) { r, ok := s.Next() for unicode.IsSpace(r) && ok { @@ -91,7 +91,6 @@ type ConnectionURL struct { var escaper = strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) func (c ConnectionURL) String() (s string) { - u := make([]string, 0, 6) // TODO: This surely needs some sort of escaping. @@ -127,7 +126,7 @@ func (c ConnectionURL) String() (s string) { c.Options = map[string]string{} } - // If not present, SSL mode is asumed disabled. + // If not present, SSL mode is assumed disabled. if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" { c.Options["sslmode"] = "disable" } @@ -168,8 +167,15 @@ func ParseURL(s string) (u ConnectionURL, err error) { u.Database = o.Get("dbname") - u.Options = map[string]string{ - "sslmode": o.Get("sslmode"), + u.Options = make(map[string]string) + + for k := range o { + switch k { + case "user", "password", "host", "port", "dbname": + // Skip + default: + u.Options[k] = o[k] + } } return u, err diff --git a/postgresql/connection_test.go b/postgresql/connection_test.go index 61b65405a87895686308d25ef27f08a2b0b58986..2600b613081af81d50cfeee78531914bcbc5ed24 100644 --- a/postgresql/connection_test.go +++ b/postgresql/connection_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -24,11 +24,10 @@ package postgresql import ( "testing" - "upper.io/db" + "upper.io/v2/db" ) func TestConnectionURL(t *testing.T) { - c := ConnectionURL{} // Default connection string is empty. @@ -79,7 +78,6 @@ func TestConnectionURL(t *testing.T) { if c.String() != `user=Anakin password=Some\ Sort\ of\ \'\ Password host=localhost port=1234 dbname=MyDatabase sslmode=verify-full` { t.Fatal(`Test failed, got:`, c.String()) } - } func TestParseConnectionURL(t *testing.T) { @@ -159,4 +157,21 @@ func TestParseConnectionURL(t *testing.T) { t.Fatal("Failed to parse SSLMode.") } + s = "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if len(u.Options) != 2 { + t.Fatal("Expecting exactly two options.") + } + + if u.Options["sslmode"] != "verify-full" { + t.Fatal("Failed to parse SSLMode.") + } + + if u.Options["timezone"] != "UTC" { + t.Fatal("Failed to parse timezone.") + } } diff --git a/postgresql/database.go b/postgresql/database.go index af7a861de579f826ffbe6aa2f1e3d70fbd81c70b..4a03989b579d2c23fb0f0e5f0df61e9427e62527 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -24,142 +24,268 @@ package postgresql import ( "database/sql" "fmt" - "os" "strconv" "strings" "time" - // Importing PostgreSQL driver. - _ "github.com/xiam/gopostgresql" - "upper.io/cache" - "upper.io/db" - "upper.io/db/util/schema" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -const ( - // Adapter is the public name of the adapter. - Adapter = `postgresql` + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" // PostgreSQL driver. + "upper.io/v2/db" + "upper.io/v2/db/util/schema" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/tx" ) var ( - // DateFormat defines the format used for storing dates. - DateFormat = "2006-01-02 15:04:05.999999999 MST" - // TimeFormat defines the format used for storing time values. - TimeFormat = "%d:%02d:%02d.%d" - // SSLMode defined wheter to enable or disable SSL connections to PostgreSQL - // server (deprecated). - SSLMode = false + sqlPlaceholder = sqlgen.RawValue(`?`) ) -var template *sqlgen.Template +type database struct { + connURL db.ConnectionURL + session *sqlx.DB + tx *sqltx.Tx + schema *schema.DatabaseSchema +} + +type tx struct { + *sqltx.Tx + *database +} var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) ) -type source struct { - connURL db.ConnectionURL - session *sql.DB - tx *tx - schema *schema.DatabaseSchema +type columnSchemaT struct { + Name string `db:"column_name"` + DataType string `db:"data_type"` } -type columnSchemaT struct { - Name string `db:"column_name"` +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error + + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + + conn := ConnectionURL{ + User: settings.User, + Password: settings.Password, + Address: db.HostPort(settings.Host, uint(settings.Port)), + Database: settings.Database, + Options: map[string]string{ + "sslmode": "disable", + }, + } + + d.connURL = conn + } + + if d.session, err = sqlx.Open(`postgres`, d.connURL.String()); err != nil { + return err + } + + d.session.Mapper = sqlutil.NewMapper() + + if err = d.populateSchema(); err != nil { + return err } - return false + + return nil +} + +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +func (d *database) clone() (*database, error) { + src := new(database) + src.Setup(d.connURL) + + if err := src.Open(); err != nil { + return nil, err } + + return src, nil } -func init() { - - template = &sqlgen.Template{ - pgsqlColumnSeparator, - pgsqlIdentifierSeparator, - pgsqlIdentifierQuote, - pgsqlValueSeparator, - pgsqlValueQuote, - pgsqlAndKeyword, - pgsqlOrKeyword, - pgsqlNotKeyword, - pgsqlDescKeyword, - pgsqlAscKeyword, - pgsqlDefaultOperator, - pgsqlClauseGroup, - pgsqlClauseOperator, - pgsqlColumnValue, - pgsqlTableAliasLayout, - pgsqlColumnAliasLayout, - pgsqlSortByColumnLayout, - pgsqlWhereLayout, - pgsqlOrderByLayout, - pgsqlInsertLayout, - pgsqlSelectLayout, - pgsqlUpdateLayout, - pgsqlDeleteLayout, - pgsqlTruncateLayout, - pgsqlDropDatabaseLayout, - pgsqlDropTableLayout, - pgsqlSelectCountLayout, - pgsqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -func (s *source) populateSchema() (err error) { - var collections []string +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil +} - s.schema = schema.NewDatabaseSchema() +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error - // Get database name. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {sqlgen.Raw{`CURRENT_DATABASE()`}}, - }, + if len(names) == 0 { + return nil, db.ErrMissingCollectionName + } + + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } } - var row *sql.Row + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper - if row, err = s.doQueryRow(stmt); err != nil { - return err + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) + + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } + + tableName := chunks[0] + + if err := d.tableExists(tableName); err != nil { + return nil, err + } + + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } } - if err = row.Scan(&s.schema.Name); err != nil { - return err + return col, nil +} + +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { + + tablesInSchema := len(d.schema.Tables) + + // Is schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil } - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { - return err + // Schema is empty. + + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Table: sqlgen.TableWithName(`information_schema.tables`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlgen.NewValue(`public`), + }, + ), } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err + // Executing statement. + var rows *sqlx.Rows + if rows, err = d.Query(stmt); err != nil { + return nil, err + } + + defer rows.Close() + + collections = []string{} + + var name string + + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err } + + // Adding table entry to schema. + d.schema.AddTable(name) + + // Adding table to collections array. + collections = append(collections, name) + } + + return collections, nil +} + +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL + + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err } + conn.Database = database + + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) return err } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name +} + +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err + } + + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err + } + + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil +} + +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -169,31 +295,32 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - res, err = s.session.Exec(query, args...) + res, err = d.session.Exec(query, args...) } return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { - var rows *sql.Rows +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows var query string var err error var start, end int64 @@ -202,32 +329,33 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - rows, err = s.tx.sqlTx.Query(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - rows, err = s.session.Query(query, args...) + rows, err = d.session.Queryx(query, args...) } return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string - var row *sql.Row + var row *sqlx.Row var err error var start, end int64 @@ -235,251 +363,106 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, `$`+strconv.Itoa(i+1), 1) } - if s.tx != nil { - row = s.tx.sqlTx.QueryRow(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - row = s.session.QueryRow(query, args...) + row = d.session.QueryRowx(query, args...) } return row, err } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} - -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) clone() (*source, error) { - src := new(source) - src.Setup(s.connURL) - - if err := src.Open(); err != nil { - return nil, err - } - - return src, nil -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} - -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sql.Tx +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - if sqlTx, err = s.session.Begin(); err != nil { - return nil, err - } + d.schema = schema.NewDatabaseSchema() - if clone, err = s.clone(); err != nil { - return nil, err + // Get database name. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.RawValue(`CURRENT_DATABASE()`), + ), } - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { - s.connURL = connURL - return s.Open() -} - -// Returns the underlying *sql.DB instance. -func (s *source) Driver() interface{} { - return s.session -} - -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { - var err error - - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - User: settings.User, - Password: settings.Password, - Address: db.HostPort(settings.Host, uint(settings.Port)), - Database: settings.Database, - Options: map[string]string{ - "sslmode": "disable", - }, - } - - // Testing for SSLMode (deprecated) - if SSLMode { - conn.Options["sslmode"] = "verify-full" - } - - // Replace original s.connURL - s.connURL = conn - } + var row *sqlx.Row - if s.session, err = sql.Open(`postgres`, s.connURL.String()); err != nil { + if row, err = d.QueryRow(stmt); err != nil { return err } - if err = s.populateSchema(); err != nil { + if err = row.Scan(&d.schema.Name); err != nil { return err } - return nil -} - -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} - -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL - - if conn, err = ParseURL(s.connURL.String()); err != nil { + if collections, err = d.Collections(); err != nil { return err } - conn.Database = database - - s.connURL = conn - - return s.Open() -} - -// Drops the currently active database. -func (s *source) Drop() error { - - _, err := s.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlDropDatabase, - Database: sqlgen.Database{s.schema.Name}, - }) - - return err -} - -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { - - tablesInSchema := len(s.schema.Tables) - - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil - } - - // Schema is empty. - - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Table: sqlgen.Table{ - `information_schema.tables`, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_schema`}, - `=`, - sqlgen.Value{`public`}, - }, - }, - } - - // Executing statement. - var rows *sql.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err - } - - defer rows.Close() - - collections = []string{} - - var name string - - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err + for i := range collections { + if _, err = d.Collection(collections[i]); err != nil { + return err } - - // Adding table entry to schema. - s.schema.AddTable(name) - - // Adding table to collections array. - collections = append(collections, name) } - return collections, nil + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error - var rows *sql.Rows + var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.tables`}, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.tables`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_catalog`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { return db.ErrCollectionDoesNotExist } defer rows.Close() - if rows.Next() == false { + if !rows.Next() { return db.ErrCollectionDoesNotExist } } @@ -487,103 +470,62 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{ - `information_schema.columns`, - }, - Columns: sqlgen.Columns{ - {`column_name`}, - {`data_type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_catalog`}, - `=`, - sqlPlaceholder, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.columns`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`column_name`), + sqlgen.ColumnWithName(`data_type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_catalog`), + Operator: `=`, + Value: sqlPlaceholder, }, - sqlgen.ColumnValue{ - sqlgen.Column{`table_name`}, - `=`, - sqlPlaceholder, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, }, - }, + ), } - var rows *sql.Rows + var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } + defer rows.Close() + tableFields := []columnSchemaT{} if err = sqlutil.FetchRows(rows, &tableFields); err != nil { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) - } - - return s.schema.TableInfo[tableName].Columns, nil -} - -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } - - col := &table{ - source: s, - names: names, - } - - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) - - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName - } - - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } - - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err - } + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return col, nil + return d.schema.TableInfo[tableName].Columns, nil } -func (s *source) getPrimaryKey(tableName string) ([]string, error) { - - tableSchema := s.schema.Table(tableName) +func (d *database) getPrimaryKey(tableName string) ([]string, error) { + tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { return tableSchema.PrimaryKey, nil @@ -591,35 +533,37 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { // Getting primary key. See https://github.com/upper/db/issues/24. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`}, - Columns: sqlgen.Columns{ - {`pg_attribute.attname`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`pg_class.oid`}, `=`, sqlgen.Value{sqlgen.Raw{`'"` + tableName + `"'::regclass`}}}, - sqlgen.ColumnValue{sqlgen.Column{`indrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, - sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, - sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attnum`}, `=`, sqlgen.Value{sqlgen.Raw{`any(pg_index.indkey)`}}}, - sqlgen.Raw{`indisprimary`}, - }, - OrderBy: sqlgen.OrderBy{ - sqlgen.SortColumns{ - { - sqlgen.Column{`attname`}, - sqlgen.SqlSortAsc, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`pg_index, pg_class, pg_attribute`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`pg_attribute.attname`), + ), + Where: sqlgen.WhereConditions( + sqlgen.RawValue(`pg_class.oid = '"`+tableName+`"'::regclass`), + sqlgen.RawValue(`indrelid = pg_class.oid`), + sqlgen.RawValue(`pg_attribute.attrelid = pg_class.oid`), + sqlgen.RawValue(`pg_attribute.attnum = ANY(pg_index.indkey)`), + sqlgen.RawValue(`indisprimary`), + ), + OrderBy: &sqlgen.OrderBy{ + SortColumns: sqlgen.JoinSortColumns( + &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(`attname`), + Order: sqlgen.Ascendent, }, - }, + ), }, } - var rows *sql.Rows + var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt); err != nil { + if rows, err = d.Query(stmt); err != nil { return nil, err } + defer rows.Close() + tableSchema.PrimaryKey = make([]string, 0, 1) for rows.Next() { diff --git a/postgresql/database_test.go b/postgresql/database_test.go index a4e911396a03d15e37f2a0fd013c2eaad4b950b6..45f0b81855bd73a6774ffb1f41b699372c0903b0 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -21,18 +21,10 @@ package postgresql -// In order to execute these tests you must initialize the database first: -// -// cd _dumps -// make -// cd .. -// go test - import ( "database/sql" "errors" - "flag" - "math/rand" + "fmt" "os" "reflect" "strconv" @@ -40,49 +32,64 @@ import ( "testing" "time" - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util/sqlutil" + "math/rand" + + "github.com/jmoiron/sqlx" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlutil" ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" +) + +const ( + testTimeZone = "Canada/Eastern" ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, + Options: map[string]string{ + "timezone": testTimeZone, + }, } -var host = flag.String("host", "testserver.local", "Testing server address.") +var host string // Structure for testing conversions and datatypes. type testValuesStruct struct { - Uint uint `field:"_uint"` - Uint8 uint8 `field:"_uint8"` - Uint16 uint16 `field:"_uint16"` - Uint32 uint32 `field:"_uint32"` - Uint64 uint64 `field:"_uint64"` - - Int int `field:"_int"` - Int8 int8 `field:"_int8"` - Int16 int16 `field:"_int16"` - Int32 int32 `field:"_int32"` - Int64 int64 `field:"_int64"` - - Float32 float32 `field:"_float32"` - Float64 float64 `field:"_float64"` - - Bool bool `field:"_bool"` - String string `field:"_string"` - - Date time.Time `field:"_date"` - DateN *time.Time `field:"_nildate"` - DateP *time.Time `field:"_ptrdate"` - Time time.Duration `field:"_time"` + Uint uint `db:"_uint"` + Uint8 uint8 `db:"_uint8"` + Uint16 uint16 `db:"_uint16"` + Uint32 uint32 `db:"_uint32"` + Uint64 uint64 `db:"_uint64"` + + Int int `db:"_int"` + Int8 int8 `db:"_int8"` + Int16 int16 `db:"_int16"` + Int32 int32 `db:"_int32"` + Int64 int64 `db:"_int64"` + + Float32 float32 `db:"_float32"` + Float64 float64 `db:"_float64"` + + Bool bool `db:"_bool"` + String string `db:"_string"` + + Date time.Time `db:"_date"` + DateN *time.Time `db:"_nildate"` + DateP *time.Time `db:"_ptrdate"` + DateD *time.Time `db:"_defaultdate,omitempty"` + Time int64 `db:"_time"` +} + +type artistType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` } type artistWithInt64Key struct { @@ -111,8 +118,8 @@ func (item itemWithKey) Constraint() db.Cond { func (item *itemWithKey) SetID(keys map[string]interface{}) error { if len(keys) == 2 { - item.Code = keys["code"].(string) - item.UserID = keys["user_id"].(string) + item.Code = string(keys["code"].([]byte)) + item.UserID = string(keys["user_id"].([]byte)) return nil } return errors.New(`Expecting exactly two keys.`) @@ -121,8 +128,14 @@ func (item *itemWithKey) SetID(keys map[string]interface{}) error { var testValues testValuesStruct func init() { + loc, err := time.LoadLocation(testTimeZone) - t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) + if err != nil { + panic(err.Error()) + } + + t := time.Date(2011, 7, 28, 1, 2, 3, 0, loc) // timestamp with time zone + tnz := time.Date(2012, 7, 28, 1, 2, 3, 0, time.FixedZone("", 0)) // timestamp without time zone testValues = testValuesStruct{ 1, 1, 1, 1, 1, @@ -132,18 +145,22 @@ func init() { "Hello world!", t, nil, - &t, - time.Second * time.Duration(7331), + &tnz, + nil, + int64(time.Second * time.Duration(7331)), + } + + if host = os.Getenv("TEST_HOST"); host == "" { + host = "localhost" } - flag.Parse() - settings.Address = db.ParseAddress(*host) + settings.Address = db.ParseAddress(host) } -// Loggin some information to stdout (like the SQL query and its +// Logging some information to stdout (like the SQL query and its // arguments), useful for development. func TestEnableDebug(t *testing.T) { - os.Setenv(db.EnvEnableDebug, "TRUE") + // os.Setenv(db.EnvEnableDebug, "TRUE") } // Attempts to open an empty datasource. @@ -158,14 +175,14 @@ func TestOpenFailed(t *testing.T) { } // Attempts to open an empty datasource. -func TestOpenWithWrongData(t *testing.T) { +func SkipTestOpenWithWrongData(t *testing.T) { var err error var rightSettings, wrongSettings db.Settings // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, - Host: *host, + Database: databaseName, + Host: host, User: username, Password: password, } @@ -178,9 +195,9 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, - Host: *host, - User: username, + Database: "fail", + Host: host, + User: "fail", Password: "fail", } @@ -191,7 +208,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong database. wrongSettings = db.Settings{ Database: "fail", - Host: *host, + Host: host, User: username, Password: password, } @@ -202,8 +219,8 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, - Host: *host, + Database: databaseName, + Host: host, User: "fail", Password: password, } @@ -219,10 +236,10 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, - Host: *host, + Host: host, } // Opening database. @@ -234,7 +251,7 @@ func TestOldSettings(t *testing.T) { sess.Close() } -// Test USE +// Test Use func TestUse(t *testing.T) { var err error var sess db.Database @@ -246,7 +263,7 @@ func TestUse(t *testing.T) { // Connecting to another database, error expected. if err = sess.Use("Another database"); err == nil { - t.Fatal("This database does not exists!") + t.Fatal("This database should not exist!") } // Closing connection. @@ -346,7 +363,7 @@ func TestAppend(t *testing.T) { t.Fatal(err) } - if to.Int64(id) == 0 { + if pk, ok := id.(int64); !ok || pk == 0 { t.Fatalf("Expecting an ID.") } @@ -361,7 +378,7 @@ func TestAppend(t *testing.T) { t.Fatal(err) } - if to.Int64(id) == 0 { + if pk, ok := id.(int64); !ok || pk == 0 { t.Fatalf("Expecting an ID.") } @@ -376,7 +393,7 @@ func TestAppend(t *testing.T) { t.Fatal(err) } - if to.Int64(id) == 0 { + if pk, ok := id.(int64); !ok || pk == 0 { t.Fatalf("Expecting an ID.") } @@ -466,39 +483,10 @@ func TestResultFetch(t *testing.T) { } if err == nil { - if to.Int64(rowMap["id"]) == 0 { - t.Fatalf("Expecting a not null ID.") - } - if to.String(rowMap["name"]) == "" { - t.Fatalf("Expecting a name.") - } - } else { - t.Fatal(err) - } - } - - res.Close() - - // Dumping into an struct with no tags. - rowStruct := struct { - ID uint64 - Name string - }{} - - res = artist.Find() - - for { - err = res.Next(&rowStruct) - - if err == db.ErrNoMoreRows { - break - } - - if err == nil { - if rowStruct.ID == 0 { + if id, ok := rowMap["id"].(int64); !ok || id == 0 { t.Fatalf("Expecting a not null ID.") } - if rowStruct.Name == "" { + if name, ok := rowMap["name"].([]byte); !ok || len(name) == 0 { t.Fatalf("Expecting a name.") } } else { @@ -510,8 +498,8 @@ func TestResultFetch(t *testing.T) { // Dumping into a tagged struct. rowStruct2 := struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find() @@ -537,7 +525,7 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an slice of maps. + // Dumping into a slice of maps. allRowsMap := []map[string]interface{}{} res = artist.Find() @@ -550,16 +538,15 @@ func TestResultFetch(t *testing.T) { } for _, singleRowMap := range allRowsMap { - if to.Int64(singleRowMap["id"]) == 0 { + if pk, ok := singleRowMap["id"].(int64); !ok || pk == 0 { t.Fatalf("Expecting a not null ID.") } } - // Dumping into an slice of structs. - + // Dumping into a slice of structs. allRowsStruct := []struct { - ID uint64 - Name string + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` }{} res = artist.Find() @@ -577,10 +564,10 @@ func TestResultFetch(t *testing.T) { } } - // Dumping into an slice of tagged structs. + // Dumping into a slice of tagged structs. allRowsStruct2 := []struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find() @@ -600,6 +587,164 @@ func TestResultFetch(t *testing.T) { } } +func TestResultFetchOne(t *testing.T) { + var err error + var sess db.Database + var artist db.Collection + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if artist, err = sess.Collection("artist"); err != nil { + t.Fatal(err) + } + + // Fetching one struct + var someArtist artistType + err = artist.Find().Limit(1).One(&someArtist) + if err != nil { + t.Fatal(err) + } + + if someArtist.Name == "" { + t.Fatal("Expecting an artist object with a name.") + } + if someArtist.ID <= 0 { + t.Fatal("Expecting an artist to have an ID.") + } + + // Fetching one object + var someArtistObj *artistType + err = artist.Find().Limit(1).One(&someArtistObj) + if err != nil { + t.Fatal(err) + } + + if someArtistObj.Name == "" { + t.Fatal("Expecting an artist object with a name.") + } + if someArtistObj.ID <= 0 { + t.Fatal("Expecting an artist object to have an ID.") + } +} + +func TestResultFetchAll(t *testing.T) { + var err error + var sess db.Database + var artist db.Collection + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if artist, err = sess.Collection("artist"); err != nil { + t.Fatal(err) + } + + // Fetching all artists into struct + artists := []artistType{} + err = artist.Find().All(&artists) + if err != nil { + t.Fatal(err) + } + + if len(artists) == 0 { + t.Fatal("Expecting some artists.") + } + if artists[0].Name == "" { + t.Fatal("Expecting the first artist to have a name.") + } + if artists[0].ID <= 0 { + t.Fatal("Expecting the first artist to have an ID.") + } + + // Fetching all artists into struct objects + artistObjs := []*artistType{} + err = artist.Find().All(&artistObjs) + if err != nil { + t.Fatal(err) + } + + if len(artistObjs) == 0 { + t.Fatal("Expecting some artist objects.") + } + if artistObjs[0].Name == "" { + t.Fatal("Expecting the first artist object to have a name.") + } + if artistObjs[0].ID <= 0 { + t.Fatal("Expecting the first artist object to have an ID.") + } +} + +func TestInlineStructs(t *testing.T) { + var sess db.Database + var err error + + var review db.Collection + + type reviewTypeDetails struct { + Name string `db:"name"` + Comments string `db:"comments"` + Created time.Time `db:"created"` + } + + type reviewType struct { + ID int64 `db:"id,omitempty"` + PublicationID int64 `db:"publication_id"` + Details reviewTypeDetails `db:",inline"` + } + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if review, err = sess.Collection("review"); err != nil { + t.Fatal(err) + } + + if err = review.Truncate(); err != nil { + t.Fatal(err) + } + + rec := reviewType{ + PublicationID: 123, + Details: reviewTypeDetails{ + Name: "..name..", Comments: "..comments..", + }, + } + + id, err := review.Append(rec) + if err != nil { + t.Fatal(err) + } + if id.(int64) <= 0 { + t.Fatal("bad id") + } + rec.ID = id.(int64) + + var recChk reviewType + err = review.Find().One(&recChk) + + if err != nil { + t.Fatal(err) + } + + if recChk.ID != rec.ID { + t.Fatal("ID of review does not match, expecting:", rec.ID, "got:", recChk.ID) + } + if recChk.Details.Name != rec.Details.Name { + t.Fatal("Name of inline field does not match, expecting:", + rec.Details.Name, "got:", recChk.Details.Name) + } +} + // Attempts to modify previously added rows. func TestUpdate(t *testing.T) { var err error @@ -618,8 +763,8 @@ func TestUpdate(t *testing.T) { // Defining destination struct value := struct { - ID uint64 - Name string + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` }{} // Getting the first artist. @@ -650,7 +795,7 @@ func TestUpdate(t *testing.T) { // Updating set with a struct rowStruct := struct { - Name string + Name string `db:"name"` }{strings.ToLower(value.Name)} if err = res.Update(rowStruct); err != nil { @@ -670,7 +815,7 @@ func TestUpdate(t *testing.T) { // Updating set with a tagged struct rowStruct2 := struct { Value1 string `db:"name"` - }{strings.Replace(value.Name, "z", "Z", -1)} + }{"john"} if err = res.Update(rowStruct2); err != nil { t.Fatal(err) @@ -685,6 +830,25 @@ func TestUpdate(t *testing.T) { if value.Name != rowStruct2.Value1 { t.Fatalf("Expecting a modification.") } + + // Updating set with a tagged object + rowStruct3 := &struct { + Value1 string `db:"name"` + }{"anderson"} + + if err = res.Update(rowStruct3); err != nil { + t.Fatal(err) + } + + // Pulling it again. + if err = res.One(&value); err != nil { + t.Fatal(err) + } + + // Verifying + if value.Name != rowStruct3.Value1 { + t.Fatalf("Expecting a modification.") + } } // Attempts to use functions within database queries. @@ -837,7 +1001,7 @@ func TestNullableFields(t *testing.T) { // In PostgreSQL, how we can tell if this is an invalid null? // if test.NullStringTest.Valid { - // t.Fatalf(`Expecting invalid null.`) + // t.Fatalf(`Expecting invalid null.`) // } // Testing insertion of valid nulls. @@ -960,11 +1124,6 @@ func TestRawRelations(t *testing.T) { var publication db.Collection var review db.Collection - type artistType struct { - ID int64 `db:"id,omitempty"` - Name string `db:"name"` - } - type publicationType struct { ID int64 `db:"id,omitempty"` Title string `db:"title"` @@ -1131,9 +1290,9 @@ func TestRawRelations(t *testing.T) { func TestRawQuery(t *testing.T) { var sess db.Database - var rows *sql.Rows + var rows *sqlx.Rows var err error - var drv *sql.DB + var drv *sqlx.DB type publicationType struct { ID int64 `db:"id,omitempty"` @@ -1147,19 +1306,19 @@ func TestRawQuery(t *testing.T) { defer sess.Close() - drv = sess.Driver().(*sql.DB) + drv = sess.Driver().(*sqlx.DB) - rows, err = drv.Query(` - SELECT - p.id, - p.title AS publication_title, - a.name AS artist_name - FROM - artist AS a, - publication AS p - WHERE - a.id = p.author_id - `) + rows, err = drv.Queryx(` + SELECT + p.id, + p.title AS publication_title, + a.name AS artist_name + FROM + artist AS a, + publication AS p + WHERE + a.id = p.author_id + `) if err != nil { t.Fatal(err) @@ -1181,11 +1340,6 @@ func TestTransactionsAndRollback(t *testing.T) { var sess db.Database var err error - type artistType struct { - ID int64 `db:"id,omitempty"` - Name string `db:"name"` - } - if sess, err = db.Open(Adapter, settings); err != nil { t.Fatal(err) } @@ -1402,6 +1556,8 @@ func TestCompositeKeys(t *testing.T) { // then it tries to get the stored datatypes and check if the stored and the // original values match. func TestDataTypes(t *testing.T) { + // os.Setenv(db.EnvEnableDebug, "TRUE") + var res db.Result var sess db.Database var dataTypes db.Collection @@ -1444,14 +1600,300 @@ func TestDataTypes(t *testing.T) { // Trying to dump the subject into an empty structure of the same type. var item testValuesStruct - res.One(&item) + if err = res.One(&item); err != nil { + t.Fatal(err) + } + + if item.DateD == nil { + t.Fatal("Expecting default date to have been set on append") + } + + // Copy the default date (this value is set by the database) + testValues.DateD = item.DateD // The original value and the test subject must match. if reflect.DeepEqual(item, testValues) == false { + fmt.Printf("item1: %v\n", item) + fmt.Printf("test2: %v\n", testValues) t.Fatalf("Struct is different.") } } +func TestOptionTypes(t *testing.T) { + var err error + var sess db.Database + var optionTypes db.Collection + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if optionTypes, err = sess.Collection("option_types"); err != nil { + t.Fatal(err) + } + + if err = optionTypes.Truncate(); err != nil { + t.Fatal(err) + } + + // TODO: lets do some benchmarking on these auto-wrapped option types.. + + // TODO: add nullable jsonb field mapped to a []string + + // A struct with wrapped option types defined in the struct tags + // for postgres string array and jsonb types + type optionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags,stringarray"` + Settings map[string]interface{} `db:"settings,jsonb"` + } + + // Item 1 + item1 := optionType{ + Name: "Food", + Tags: []string{"toronto", "pizza"}, + Settings: map[string]interface{}{"a": 1, "b": 2}, + } + + id, err := optionTypes.Append(item1) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1Chk optionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1Chk); err != nil { + t.Fatal(err) + } + + if item1Chk.Settings["a"].(float64) != 1 { // float64 because of json.. + t.Fatalf("Expecting Settings['a'] of jsonb value to be 1") + } + + if item1Chk.Tags[0] != "toronto" { + t.Fatalf("Expecting first element of Tags stringarray to be 'toronto'") + } + + // Item 1 B + item1b := &optionType{ + Name: "Golang", + Tags: []string{"love", "it"}, + Settings: map[string]interface{}{"go": 1, "lang": 2}, + } + + id, err = optionTypes.Append(item1b) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1bChk optionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1bChk); err != nil { + t.Fatal(err) + } + + if item1bChk.Settings["go"].(float64) != 1 { // float64 because of json.. + t.Fatalf("Expecting Settings['go'] of jsonb value to be 1") + } + + if item1bChk.Tags[0] != "love" { + t.Fatalf("Expecting first element of Tags stringarray to be 'love'") + } + + // Item 1 C + item1c := &optionType{ + Name: "Sup", Tags: []string{}, Settings: map[string]interface{}{}, + } + + id, err = optionTypes.Append(item1c) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1cChk optionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1cChk); err != nil { + t.Fatal(err) + } + + if len(item1cChk.Tags) != 0 { + t.Fatalf("Expecting tags array to be empty but is %v", item1cChk.Tags) + } + + if len(item1cChk.Settings) != 0 { + t.Fatalf("Expecting Settings map to be empty") + } + + // An option type to pointer jsonb field + type optionType2 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags,stringarray"` + Settings *map[string]interface{} `db:"settings,jsonb"` + } + + item2 := optionType2{ + Name: "JS", Tags: []string{"hi", "bye"}, Settings: nil, + } + + id, err = optionTypes.Append(item2) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item2Chk optionType2 + res := optionTypes.Find(db.Cond{"id": id}) + if err := res.One(&item2Chk); err != nil { + t.Fatal(err) + } + + if item2Chk.ID != id.(int64) { + t.Fatalf("Expecting id to match") + } + + if item2Chk.Name != item2.Name { + t.Fatalf("Expecting Name to match") + } + + if item2Chk.Tags[0] != item2.Tags[0] || len(item2Chk.Tags) != len(item2.Tags) { + t.Fatalf("Expecting tags to match") + } + + // Update the value + m := map[string]interface{}{} + m["lang"] = "javascript" + m["num"] = 31337 + item2.Settings = &m + err = res.Update(item2) + if err != nil { + t.Fatal(err) + } + + if err := res.One(&item2Chk); err != nil { + t.Fatal(err) + } + + if (*item2Chk.Settings)["num"].(float64) != 31337 { // float64 because of json.. + t.Fatalf("Expecting Settings['num'] of jsonb value to be 31337") + } + + if (*item2Chk.Settings)["lang"] != "javascript" { + t.Fatalf("Expecting Settings['lang'] of jsonb value to be 'javascript'") + } + + // An option type to pointer string array field + type optionType3 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags *[]string `db:"tags,stringarray"` + Settings map[string]interface{} `db:"settings,jsonb"` + } + + item3 := optionType3{ + Name: "Julia", Tags: nil, Settings: map[string]interface{}{"girl": true, "lang": true}, + } + + id, err = optionTypes.Append(item3) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item3Chk optionType2 + if err := optionTypes.Find(db.Cond{"id": id}).One(&item3Chk); err != nil { + t.Fatal(err) + } +} + +func TestOptionTypeJsonbStruct(t *testing.T) { + var err error + var sess db.Database + var optionTypes db.Collection + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if optionTypes, err = sess.Collection("option_types"); err != nil { + t.Fatal(err) + } + + if err = optionTypes.Truncate(); err != nil { + t.Fatal(err) + } + + // A struct with wrapped option types defined in the struct tags + // for postgres string array and jsonb types + type Settings struct { + Name string `json:"name"` + Num int64 `json:"num"` + } + + type OptionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags,stringarray"` + Settings Settings `db:"settings,jsonb"` + } + + item1 := &OptionType{ + Name: "Hi", + Tags: []string{"aah", "ok"}, + Settings: Settings{Name: "a", Num: 123}, + } + + id, err := optionTypes.Append(item1) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1Chk OptionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1Chk); err != nil { + t.Fatal(err) + } + + if len(item1Chk.Tags) != 2 { + t.Fatalf("Expecting 2 tags") + } + + if item1Chk.Tags[0] != "aah" { + t.Fatalf("Expecting first tag to be 0") + } + + if item1Chk.Settings.Name != "a" { + t.Fatalf("Expecting Name to be 'a'") + } + + if item1Chk.Settings.Num != 123 { + t.Fatalf("Expecting Num to be 123") + } +} + // We are going to benchmark the engine, so this is no longed needed. func TestDisableDebug(t *testing.T) { os.Setenv(db.EnvEnableDebug, "") @@ -1468,7 +1910,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec(`TRUNCATE TABLE "artist"`); err != nil { b.Fatal(err) @@ -1522,7 +1964,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/postgresql/layout.go b/postgresql/layout.go deleted file mode 100644 index 13c24b3667e85ea5501adf98259668a7f6eeb6ea..0000000000000000000000000000000000000000 --- a/postgresql/layout.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package postgresql - -const ( - pgsqlColumnSeparator = `.` - pgsqlIdentifierSeparator = `, ` - pgsqlIdentifierQuote = `"{{.Raw}}"` - pgsqlValueSeparator = `, ` - pgsqlValueQuote = `'{{.}}'` - pgsqlAndKeyword = `AND` - pgsqlOrKeyword = `OR` - pgsqlNotKeyword = `NOT` - pgsqlDescKeyword = `DESC` - pgsqlAscKeyword = `ASC` - pgsqlDefaultOperator = `=` - pgsqlClauseGroup = `({{.}})` - pgsqlClauseOperator = ` {{.}} ` - pgsqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - pgsqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - pgsqlOrderByLayout = ` - {{if .SortColumns}} - ORDER BY {{.SortColumns}} - {{end}} - ` - - pgsqlWhereLayout = ` - {{if .Conds}} - WHERE {{.Conds}} - {{end}} - ` - - pgsqlSelectLayout = ` - SELECT - - {{if .Columns}} - {{.Columns}} - {{else}} - * - {{end}} - - {{if .Table}} - FROM {{.Table}} - {{end}} - - {{.Where}} - - {{.GroupBy}} - - {{.OrderBy}} - - {{if .Limit}} - LIMIT {{.Limit}} - {{end}} - - {{if .Offset}} - OFFSET {{.Offset}} - {{end}} - ` - pgsqlDeleteLayout = ` - DELETE - FROM {{.Table}} - {{.Where}} - ` - pgsqlUpdateLayout = ` - UPDATE - {{.Table}} - SET {{.ColumnValues}} - {{ .Where }} - ` - - pgsqlSelectCountLayout = ` - SELECT - COUNT(1) AS _t - FROM {{.Table}} - {{.Where}} - - {{if .Limit}} - LIMIT {{.Limit}} - {{end}} - - {{if .Offset}} - OFFSET {{.Offset}} - {{end}} - ` - - pgsqlInsertLayout = ` - INSERT INTO {{.Table}} - ({{.Columns}}) - VALUES - ({{.Values}}) - {{.Extra}} - ` - - pgsqlTruncateLayout = ` - TRUNCATE TABLE {{.Table}} RESTART IDENTITY - ` - - pgsqlDropDatabaseLayout = ` - DROP DATABASE {{.Database}} - ` - - pgsqlDropTableLayout = ` - DROP TABLE {{.Table}} - ` - - pgsqlGroupByLayout = ` - {{if .GroupColumns}} - GROUP BY {{.GroupColumns}} - {{end}} - ` - - psqlNull = `NULL` -) diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go new file mode 100644 index 0000000000000000000000000000000000000000..f15ef8703a0ba6e31001e4d63960fee0e6088437 --- /dev/null +++ b/postgresql/postgresql.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package postgresql // import "upper.io/v2/db/postgresql" + +import ( + "upper.io/cache" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `postgresql` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/postgresql/template.go b/postgresql/template.go new file mode 100644 index 0000000000000000000000000000000000000000..aa414ecbb85e1065fa52a2145eb403ac69e6c690 --- /dev/null +++ b/postgresql/template.go @@ -0,0 +1,135 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package postgresql + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + + {{if .Columns}} + {{.Columns}} + {{else}} + * + {{end}} + + {{if .Table}} + FROM {{.Table}} + {{end}} + + {{.Where}} + + {{.GroupBy}} + + {{.OrderBy}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table}} + {{.Where}} + ` + adapterUpdateLayout = ` + UPDATE + {{.Table}} + SET {{.ColumnValues}} + {{ .Where }} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table}} + {{.Where}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table}} + ({{.Columns}}) + VALUES + ({{.Values}}) + {{.Extra}} + ` + + adapterTruncateLayout = ` + TRUNCATE TABLE {{.Table}} RESTART IDENTITY + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) diff --git a/postgresql/tx.go b/postgresql/tx.go deleted file mode 100644 index 1970e35c2f1a280a50e24a126f4686f0e756695a..0000000000000000000000000000000000000000 --- a/postgresql/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package postgresql - -import ( - "database/sql" -) - -type tx struct { - *source - sqlTx *sql.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/ql/_dumps/Makefile b/ql/_dumps/Makefile index cba56b0150415fb40b311ce885b8d59a497ff8b8..1703af5222b0476e569a542c0cc26cdefd13489d 100644 --- a/ql/_dumps/Makefile +++ b/ql/_dumps/Makefile @@ -1,3 +1,3 @@ all: rm -f test.db - cat structs.sql | ql -db test.db + cat structs.sql | $$GOPATH/bin/ql -db test.db diff --git a/ql/_dumps/structs.sql b/ql/_dumps/structs.sql index 96ca94f519fc18d7c4a6620f68396d8415f9e654..373389a0d01e249423d7589f98507db9db9343f0 100644 --- a/ql/_dumps/structs.sql +++ b/ql/_dumps/structs.sql @@ -43,6 +43,7 @@ CREATE TABLE data_types ( _date time, _nildate time, _ptrdate time, + _defaultdate time, _time time ); diff --git a/ql/_example/main.go b/ql/_example/main.go index 297caef56d00518a898a1cbf1e4b1e84658e88b0..33a1c75fb18aec7ba7b49881ff6d056d19067196 100644 --- a/ql/_example/main.go +++ b/ql/_example/main.go @@ -5,8 +5,8 @@ import ( "log" "time" - "upper.io/db" // Imports the main db package. - _ "upper.io/db/ql" // Imports the ql adapter. + "upper.io/v2/db" // Imports the main db package. + _ "upper.io/v2/db/ql" // Imports the ql adapter. ) var settings = db.Settings{ diff --git a/ql/collection.go b/ql/collection.go index d50dd3ea1652d24ac8405851c8d1926b42c73121..68b9cb3b7837bfcabb423b5c64220b0e9069a9df 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,231 +22,68 @@ package ql import ( - "fmt" + "database/sql" "reflect" "strings" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/result" ) -const defaultOperator = `==` - type table struct { sqlutil.T - columnTypes map[string]reflect.Kind - source *source + *database names []string + columnTypes map[string]reflect.Kind } -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} +var _ = db.Collection(&table{}) +// Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: t, - where: where, - arguments: arguments, - } - - return result + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - - _, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{t.tableN(0)}, + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. +// Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - cols, vals, err := t.FieldValues(item, toInternal) - - var columns sqlgen.Columns - var values sqlgen.Values + columnNames, columnValues, err := t.FieldValues(item) - for _, col := range cols { - columns = append(columns, sqlgen.Column{col}) + if err != nil { + return nil, err } - for i := 0; i < len(vals); i++ { - values = append(values, sqlPlaceholder) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - // Error ocurred, stop appending. if err != nil { return nil, err } - res, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{t.tableN(0)}, - Columns: columns, - Values: values, - }, vals...) + stmt := sqlgen.Statement{ + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, + } - if err != nil { + var res sql.Result + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -263,18 +100,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { return id, nil } -// Returns true if the collection exists. +// Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) -} - -func toInternal(v interface{}) interface{} { - return v + return strings.Join(t.Tables, `, `) } diff --git a/ql/database.go b/ql/database.go index 0ba62d4e990300ef33a766f2ca1818f3fb57ef5e..0892610d2c5ee1bac9df035c95077eb27f6a3992 100644 --- a/ql/database.go +++ b/ql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -24,127 +24,250 @@ package ql import ( "database/sql" "fmt" - "os" - "reflect" "strings" "time" - // Importing QL driver - _ "github.com/cznic/ql/driver" - "upper.io/cache" - "upper.io/db" - "upper.io/db/util/schema" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -// Public adapters name under which this adapter registers its. -const Adapter = `ql` -var ( - // DateFormat defines the format used for storing dates. - DateFormat = "2006-01-02 15:04:05.000" - // TimeFormat defines the format used for storing time values. - TimeFormat = "%d:%02d:%02d.%03d" - timeType = reflect.TypeOf(time.Time{}).Kind() + _ "github.com/cznic/ql/driver" // QL driver + "github.com/jmoiron/sqlx" + "upper.io/v2/db" + "upper.io/v2/db/util/schema" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/tx" ) -var template *sqlgen.Template - var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL - session *sql.DB - tx *tx + session *sqlx.DB + tx *sqltx.Tx schema *schema.DatabaseSchema } +type tx struct { + *sqltx.Tx + *database +} + +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) + type columnSchemaT struct { Name string `db:"Name"` } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session +} + +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error + + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + Database: settings.Database, + } + + d.connURL = conn } - return false + + if d.session, err = sqlx.Open(`ql`, d.connURL.String()); err != nil { + return err + } + + d.session.Mapper = sqlutil.NewMapper() + + if err = d.populateSchema(); err != nil { + return err + } + + return nil +} + +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +func (d *database) clone() (adapter *database, err error) { + adapter = new(database) + + if err = adapter.Setup(d.connURL); err != nil { + return nil, err } + + return adapter, nil } -func init() { - - template = &sqlgen.Template{ - qlColumnSeparator, - qlIdentifierSeparator, - qlIdentifierQuote, - qlValueSeparator, - qlValueQuote, - qlAndKeyword, - qlOrKeyword, - qlNotKeyword, - qlDescKeyword, - qlAscKeyword, - qlDefaultOperator, - qlClauseGroup, - qlClauseOperator, - qlColumnValue, - qlTableAliasLayout, - qlColumnAliasLayout, - qlSortByColumnLayout, - qlWhereLayout, - qlOrderByLayout, - qlInsertLayout, - qlSelectLayout, - qlUpdateLayout, - qlDeleteLayout, - qlTruncateLayout, - qlDropDatabaseLayout, - qlDropTableLayout, - qlSelectCountLayout, - qlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -func (s *source) populateSchema() (err error) { - var collections []string +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil +} + +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error + + if len(names) == 0 { + return nil, db.ErrMissingCollectionName + } + + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } - s.schema = schema.NewDatabaseSchema() + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper + + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) + + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } + + tableName := chunks[0] + + if err := d.tableExists(tableName); err != nil { + return nil, err + } + + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } + } + return col, nil +} + +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { + + tablesInSchema := len(d.schema.Tables) + + // Is schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil + } + + // Schema is empty. + + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Table`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + ), + } + + // Executing statement. + var rows *sqlx.Rows + if rows, err = d.Query(stmt); err != nil { + return nil, err + } + + defer rows.Close() + + collections = []string{} + + var name string + + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err + } + + // Adding table entry to schema. + d.schema.AddTable(name) + + // Adding table to collections array. + collections = append(collections, name) + } + + return collections, nil +} + +// Use changes the active database. +func (d *database) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } - s.schema.Name = conn.Database + conn.Database = database - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { - return err + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + return db.ErrUnsupported +} + +// Setup stores database settings. +func (d *database) Setup(conn db.ConnectionURL) error { + d.connURL = conn + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name +} + +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err - } + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err } - return err + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -154,26 +277,26 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - var tx *sql.Tx + var tx *sqlx.Tx - if tx, err = s.session.Begin(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -189,8 +312,9 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { - var rows *sql.Rows +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows var query string var err error var start, end int64 @@ -199,30 +323,30 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - rows, err = s.tx.sqlTx.Query(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - var tx *sql.Tx + var tx *sqlx.Tx - if tx, err = s.session.Begin(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } - if rows, err = tx.Query(query, args...); err != nil { + if rows, err = tx.Queryx(query, args...); err != nil { return nil, err } @@ -234,9 +358,10 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string - var row *sql.Row + var row *sqlx.Row var err error var start, end int64 @@ -244,30 +369,30 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - row = s.tx.sqlTx.QueryRow(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - var tx *sql.Tx + var tx *sqlx.Tx - if tx, err = s.session.Begin(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } - if row = tx.QueryRow(query, args...); err != nil { + if row = tx.QueryRowx(query, args...); err != nil { return nil, err } @@ -279,192 +404,64 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro return row, err } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} - -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) clone() (adapter *source, err error) { - adapter = new(source) - - if err = adapter.Setup(s.connURL); err != nil { - return nil, err - } - - return adapter, nil -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} - -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sql.Tx - - if clone, err = s.clone(); err != nil { - return nil, err - } - - if sqlTx, err = s.session.Begin(); err != nil { - return nil, err - } - - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(conn db.ConnectionURL) error { - s.connURL = conn - return s.Open() -} - -// Returns the underlying *sql.DB instance. -func (s *source) Driver() interface{} { - return s.session -} - -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { - var err error - - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - Database: settings.Database, - } - - s.connURL = conn - } - - if s.session, err = sql.Open(`ql`, s.connURL.String()); err != nil { - return err - } - - if err = s.populateSchema(); err != nil { - return err - } +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - return nil -} + d.schema = schema.NewDatabaseSchema() -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} - -// Changes the active database. -func (s *source) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } - conn.Database = database - - s.connURL = conn - - return s.Open() -} - -// Drops the currently active database. -func (s *source) Drop() error { - return db.ErrUnsupported -} - -// Returns a list of all tables within the currently active database. -func (s *source) Collections() (collections []string, err error) { - - tablesInSchema := len(s.schema.Tables) - - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil - } - - // Schema is empty. + d.schema.Name = conn.Database - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Table`}, - Columns: sqlgen.Columns{ - {`Name`}, - }, - } - - // Executing statement. - var rows *sql.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { + return err } - defer rows.Close() - - collections = []string{} - - var name string - - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err } - - // Adding table entry to schema. - s.schema.AddTable(name) - - // Adding table to collections array. - collections = append(collections, name) } - return collections, nil + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error - var rows *sql.Rows + var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Table`}, - Columns: sqlgen.Columns{ - {`Name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`Name`}, `==`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Table`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`Name`), + Operator: `==`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, names[i]); err != nil { + if rows, err = d.Query(stmt, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -478,31 +475,35 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Column`}, - Columns: sqlgen.Columns{ - {`Name`}, - {`Type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`TableName`}, `==`, sqlPlaceholder}, - }, - } - - var rows *sql.Rows + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Column`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + sqlgen.ColumnWithName(`Type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`TableName`), + Operator: `==`, + Value: sqlPlaceholder, + }, + ), + } + + var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, tableName); err != nil { + if rows, err = d.Query(stmt, tableName); err != nil { return nil, err } @@ -512,51 +513,11 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil -} - -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } - - col := &table{ - source: s, - names: names, - } - - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) - - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName - } - - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } - - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err - } - } - - return col, nil + return d.schema.TableInfo[tableName].Columns, nil } diff --git a/ql/database_test.go b/ql/database_test.go index 01c65acfde1f8c2fb63c282a94e73ba63fa20ec4..861dfedfe5184284edbc3d58e64526774640a44a 100644 --- a/ql/database_test.go +++ b/ql/database_test.go @@ -40,42 +40,48 @@ import ( "testing" "time" - "upper.io/db" - "upper.io/db/util/sqlutil" + "github.com/jmoiron/sqlx" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlutil" ) const ( - database = `_dumps/test.db` + databaseName = `_dumps/test.db` +) + +const ( + testTimeZone = "Canada/Eastern" ) var settings = db.Settings{ - Database: database, + Database: databaseName, } // Structure for testing conversions and datatypes. type testValuesStruct struct { - Uint uint `field:"_uint"` - Uint8 uint8 `field:"_uint8"` - Uint16 uint16 `field:"_uint16"` - Uint32 uint32 `field:"_uint32"` - Uint64 uint64 `field:"_uint64"` - - Int int `field:"_int"` - Int8 int8 `field:"_int8"` - Int16 int16 `field:"_int16"` - Int32 int32 `field:"_int32"` - Int64 int64 `field:"_int64"` - - Float32 float32 `field:"_float32"` - Float64 float64 `field:"_float64"` - - Bool bool `field:"_bool"` - String string `field:"_string"` - - Date time.Time `field:"_date"` - DateN *time.Time `field:"_nildate"` - DateP *time.Time `field:"_ptrdate"` - Time time.Duration `field:"_time"` + Uint uint `db:"_uint"` + Uint8 uint8 `db:"_uint8"` + Uint16 uint16 `db:"_uint16"` + Uint32 uint32 `db:"_uint32"` + Uint64 uint64 `db:"_uint64"` + + Int int `db:"_int"` + Int8 int8 `db:"_int8"` + Int16 int16 `db:"_int16"` + Int32 int32 `db:"_int32"` + Int64 int64 `db:"_int64"` + + Float32 float32 `db:"_float32"` + Float64 float64 `db:"_float64"` + + Bool bool `db:"_bool"` + String string `db:"_string"` + + Date time.Time `db:"_date"` + DateN *time.Time `db:"_nildate"` + DateP *time.Time `db:"_ptrdate"` + DateD *time.Time `db:"_defaultdate,omitempty"` + Time int64 `db:"_time"` } type itemWithKey struct { @@ -103,7 +109,14 @@ func (item *itemWithKey) SetID(keys map[string]interface{}) error { var testValues testValuesStruct func init() { - t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) + loc, err := time.LoadLocation(testTimeZone) + + if err != nil { + panic(err.Error()) + } + + t := time.Date(2011, 7, 28, 1, 2, 3, 0, loc) + tnz := time.Date(2012, 7, 28, 1, 2, 3, 0, time.UTC) testValues = testValuesStruct{ 1, 1, 1, 1, 1, @@ -113,8 +126,9 @@ func init() { "Hello world!", t, nil, - &t, - time.Second * time.Duration(7331), + &tnz, + nil, + int64(time.Second * time.Duration(7331)), } } @@ -141,7 +155,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, } // Opening database. @@ -417,39 +431,10 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an struct with no tags. - rowStruct := struct { - ID uint64 - Name string - }{} - - res = artist.Find().Select("id() AS id", "name") - - for { - err = res.Next(&rowStruct) - - if err == db.ErrNoMoreRows { - break - } - - if err == nil { - if rowStruct.ID == 0 { - t.Fatalf("Expecting a not null ID.") - } - if rowStruct.Name == "" { - t.Fatalf("Expecting a name.") - } - } else { - t.Fatal(err) - } - } - - res.Close() - // Dumping into a tagged struct. rowStruct2 := struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find().Select("id() AS id", "name") @@ -493,11 +478,10 @@ func TestResultFetch(t *testing.T) { } } - // Dumping into an slice of structs. - + // Dumping into a slice of structs. allRowsStruct := []struct { - ID uint64 - Name string + ID uint64 `db:"id"` + Name string `db:"name"` }{} res = artist.Find().Select("id() AS id", "name") @@ -517,8 +501,8 @@ func TestResultFetch(t *testing.T) { // Dumping into an slice of tagged structs. allRowsStruct2 := []struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find().Select("id() AS id", "name") @@ -588,7 +572,7 @@ func TestUpdate(t *testing.T) { // Updating set with a struct rowStruct := struct { - Name string + Name string `db:"name"` }{strings.ToLower(value.Name)} if err = res.Update(rowStruct); err != nil { @@ -925,9 +909,10 @@ func TestRawRelations(t *testing.T) { func TestRawQuery(t *testing.T) { var sess db.Database - var rows *sql.Rows + var rows *sqlx.Rows var err error - var drv *sql.DB + var drv *sqlx.DB + var tx *sqlx.Tx type publicationType struct { ID int64 `db:"id,omitempty"` @@ -941,9 +926,13 @@ func TestRawQuery(t *testing.T) { defer sess.Close() - drv = sess.Driver().(*sql.DB) + drv = sess.Driver().(*sqlx.DB) - rows, err = drv.Query(` + if tx, err = drv.Beginx(); err != nil { + t.Fatal(err) + } + + if rows, err = tx.Queryx(` SELECT p.id AS id, p.title AS publication_title, @@ -953,9 +942,11 @@ func TestRawQuery(t *testing.T) { (SELECT id() AS id, title, author_id FROM publication) AS p WHERE a.id == p.author_id - `) + `); err != nil { + t.Fatal(err) + } - if err != nil { + if err = tx.Commit(); err != nil { t.Fatal(err) } @@ -1276,7 +1267,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) @@ -1349,7 +1340,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/ql/ql.go b/ql/ql.go new file mode 100644 index 0000000000000000000000000000000000000000..ce4749546fe45045e70e5354d0d0567a01013312 --- /dev/null +++ b/ql/ql.go @@ -0,0 +1,72 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package ql // import "upper.io/v2/db/ql" + +import ( + "upper.io/cache" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `ql` + +var template *sqlutil.TemplateWithUtils + +func init() { + + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/ql/result.go b/ql/result.go deleted file mode 100644 index 8a0743f842b97a7b416068536d2e28b15fa36ce2..0000000000000000000000000000000000000000 --- a/ql/result.go +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "database/sql" - "fmt" - "strings" - - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"total"` -} - -type result struct { - table *table - cursor *sql.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() (err error) { - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = r.table.fetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - // Current cursor. - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - // Fetching the next result from the cursor. - if err = r.table.fetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values, toInternal) - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() error { - var err error - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts matching elements. -func (r *result) Count() (uint64, error) { - var count counter - - rows, err := r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - defer rows.Close() - if err = sqlutil.FetchRow(rows, &count); err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/ql/layout.go b/ql/template.go similarity index 65% rename from ql/layout.go rename to ql/template.go index d60f32e2d49b1bb55ba2f4dcae6dae061e31c609..61fc47c044905a107393e34539f62e3f19e1cebc 100644 --- a/ql/layout.go +++ b/ql/template.go @@ -22,37 +22,38 @@ package ql const ( - qlColumnSeparator = `.` - qlIdentifierSeparator = `, ` - qlIdentifierQuote = `{{.Raw}}` - qlValueSeparator = `, ` - qlValueQuote = `"{{.}}"` - qlAndKeyword = `&&` - qlOrKeyword = `||` - qlNotKeyword = `!=` - qlDescKeyword = `DESC` - qlAscKeyword = `ASC` - qlDefaultOperator = `==` - qlClauseGroup = `({{.}})` - qlClauseOperator = ` {{.}} ` - qlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - qlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - qlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - qlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - qlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `{{.Value}}` + adapterValueSeparator = `, ` + adapterValueQuote = `"{{.}}"` + adapterAndKeyword = `&&` + adapterOrKeyword = `||` + adapterNotKeyword = `!=` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `==` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - qlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - qlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -77,19 +78,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - qlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - qlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - qlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT count(1) AS total FROM {{.Table}} @@ -104,7 +105,7 @@ const ( {{end}} ` - qlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -112,19 +113,19 @@ const ( {{.Extra}} ` - qlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} ` - qlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - qlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - qlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} diff --git a/ql/tx.go b/ql/tx.go deleted file mode 100644 index 3cc3f38333ec15176061df414a370095938ebca8..0000000000000000000000000000000000000000 --- a/ql/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "database/sql" -) - -type tx struct { - *source - sqlTx *sql.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/ql/util.go b/ql/util.go deleted file mode 100644 index 5f421dd535e43fd6b1c6aea9025057d5f7098e5b..0000000000000000000000000000000000000000 --- a/ql/util.go +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "database/sql" - "reflect" - - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util" -) - -func (t *table) fetchRow(rows *sql.Rows, dst interface{}) error { - var err error - - dstv := reflect.ValueOf(dst) - - if dstv.IsNil() || dstv.Kind() != reflect.Ptr { - return db.ErrExpectingPointer - } - - itemV := dstv.Elem() - - next := rows.Next() - - if next == false { - if err = rows.Err(); err != nil { - return err - } - return db.ErrNoMoreRows - } - - var columns []string - - if columns, err = rows.Columns(); err != nil { - return err - } - - item, err := t.fetchResult(itemV.Type(), rows, columns) - - if err != nil { - return err - } - - itemV.Set(reflect.Indirect(item)) - - return nil -} - -func (t *table) fetchResult(itemT reflect.Type, rows *sql.Rows, columns []string) (item reflect.Value, err error) { - expecting := len(columns) - - scanArgs := make([]interface{}, expecting) - - switch itemT.Kind() { - case reflect.Struct: - // Creating new value of the expected type. - item = reflect.New(itemT) - // Pairing each column with its index. - for i, columnName := range columns { - index := util.GetStructFieldIndex(itemT, columnName) - if len(index) > 0 { - destF := item.Elem().FieldByIndex(index) - scanArgs[i] = destF.Addr().Interface() - } else { - var placeholder sql.RawBytes - scanArgs[i] = &placeholder - } - } - - err = rows.Scan(scanArgs...) - - if err != nil { - return item, err - } - case reflect.Map: - values := make([]*sql.RawBytes, len(columns)) - for i := range columns { - scanArgs[i] = &values[i] - } - err = rows.Scan(scanArgs...) - - if err == nil { - item = reflect.MakeMap(itemT) - for i, columnName := range columns { - valS := string(*values[i]) - - var vv reflect.Value - - if _, ok := t.columnTypes[columnName]; ok == true { - v, _ := to.Convert(valS, t.columnTypes[columnName]) - vv = reflect.ValueOf(v) - } else { - v, _ := to.Convert(valS, reflect.String) - vv = reflect.ValueOf(v) - } - - vk := reflect.ValueOf(columnName) - item.SetMapIndex(vk, vv) - } - } - - return item, err - default: - return item, db.ErrExpectingMapOrStruct - } - - return item, nil -} - -func (t *table) fetchRows(rows *sql.Rows, dst interface{}) error { - var err error - - // Destination. - dstv := reflect.ValueOf(dst) - - if dstv.IsNil() || dstv.Kind() != reflect.Ptr { - return db.ErrExpectingPointer - } - - if dstv.Elem().Kind() != reflect.Slice { - return db.ErrExpectingSlicePointer - } - - if dstv.Kind() != reflect.Ptr || dstv.Elem().Kind() != reflect.Slice || dstv.IsNil() { - return db.ErrExpectingSliceMapStruct - } - - slicev := dstv.Elem() - itemT := slicev.Type().Elem() - - var columns []string - - if columns, err = rows.Columns(); err != nil { - return err - } - - for rows.Next() { - - item, err := t.fetchResult(itemT, rows, columns) - - if err != nil { - return err - } - - slicev = reflect.Append(slicev, reflect.Indirect(item)) - } - - rows.Close() - - dstv.Elem().Set(slicev) - - return nil -} diff --git a/sqlite/README.md b/sqlite/README.md index 1c604de95e674c145b95bdeedf6d8bb2609a577f..747a79d69b9e8385e7c349957898e891be9fe599 100644 --- a/sqlite/README.md +++ b/sqlite/README.md @@ -1,6 +1,6 @@ -# SQLite3 adapter for upper.io/db +# SQLite3 adapter for upper.io/v2/db Please read the full docs, acknowledgements and examples at -[https://upper.io/db/sqlite][1] +[https://upper.io/v2/db/sqlite][1] -[1]: https://upper.io/db/sqlite +[1]: https://upper.io/v2/db/sqlite diff --git a/sqlite/_dumps/structs.sql b/sqlite/_dumps/structs.sql index 9c90ce9b87f7b0e61ebee07201b8ba9ebe7d5523..cc41d434333bdb3c1c89ae8eb5096172d5811c03 100644 --- a/sqlite/_dumps/structs.sql +++ b/sqlite/_dumps/structs.sql @@ -47,10 +47,10 @@ CREATE TABLE data_types ( _rune integer, _bool integer, _string text, - _date text, - _nildate text, - _ptrdate text, - _bytea text, + _date datetime, + _nildate datetime, + _ptrdate datetime, + _defaultdate datetime default current_timestamp, _time text ); diff --git a/sqlite/_example/main.go b/sqlite/_example/main.go index 0c5e1541cd4678b9771c872e7bcd06a86bb3161e..905961402bc419fe43dc6d779620672fde9bd7ed 100644 --- a/sqlite/_example/main.go +++ b/sqlite/_example/main.go @@ -5,8 +5,8 @@ import ( "log" "time" - "upper.io/db" // Imports the main db package. - _ "upper.io/db/sqlite" // Imports the sqlite adapter. + "upper.io/v2/db" // Imports the main db package. + _ "upper.io/v2/db/sqlite" // Imports the sqlite adapter. ) var settings = db.Settings{ diff --git a/sqlite/collection.go b/sqlite/collection.go index b1d8e698ec2e701998b3b37ec037a610c3b5b5b8..aa9c63d30b4aae492f289b36fdd2c2bc0e9e0fdc 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,240 +22,59 @@ package sqlite import ( - "fmt" - "reflect" "strings" - "time" "database/sql" - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source - names []string -} - -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args + *database } -func interfaceArgs(value interface{}) (args []interface{}) { - - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) +var _ = db.Collection(&table{}) - for i = 0; i < total; i++ { - args[i] = toInternal(v.Index(i).Interface()) - } - - return args - } - return nil - default: - args = []interface{}{toInternal(value)} - } - - return args +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} - -func (c *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: c, - where: where, - arguments: arguments, - } - - return result -} - -func (c *table) tableN(i int) string { - if len(c.names) > i { - chunks := strings.SplitN(c.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. -func (c *table) Truncate() error { - - _, err := c.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{c.tableN(0)}, +// Truncate deletes all rows from the table. +func (t *table) Truncate() error { + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. -func (c *table) Append(item interface{}) (interface{}, error) { - +// Append inserts an item (map or struct) into the collection. +func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - cols, vals, err := c.FieldValues(item, toInternal) + columnNames, columnValues, err := t.FieldValues(item) - // Error ocurred, stop appending. if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = c.source.getPrimaryKey(c.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -263,14 +82,14 @@ func (c *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{c.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = c.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -297,10 +116,10 @@ func (c *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -323,72 +142,15 @@ func (c *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. -func (c *table) Exists() bool { - if err := c.source.tableExists(c.names...); err != nil { +// Exists returns true if the collection exists. +func (t *table) Exists() bool { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } -func (c *table) Name() string { - return strings.Join(c.names, `, `) -} - -// Converts a Go value into internal database representation. -func toInternal(val interface{}) interface{} { - switch t := val.(type) { - case db.Marshaler: - return t - case []byte: - return string(t) - case *time.Time: - if t == nil || t.IsZero() { - return sqlgen.Value{sqlgen.Raw{sqlNull}} - } - return t.Format(DateFormat) - case time.Time: - if t.IsZero() { - return sqlgen.Value{sqlgen.Raw{sqlNull}} - } - return t.Format(DateFormat) - case time.Duration: - return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond) - case sql.NullBool: - if t.Valid { - if t.Bool { - return toInternal(t.Bool) - } - return false - } - return sqlgen.Value{sqlgen.Raw{sqlNull}} - case sql.NullFloat64: - if t.Valid { - if t.Float64 != 0.0 { - return toInternal(t.Float64) - } - return float64(0) - } - return sqlgen.Value{sqlgen.Raw{sqlNull}} - case sql.NullInt64: - if t.Valid { - if t.Int64 != 0 { - return toInternal(t.Int64) - } - return 0 - } - return sqlgen.Value{sqlgen.Raw{sqlNull}} - case sql.NullString: - if t.Valid { - return toInternal(t.String) - } - return sqlgen.Value{sqlgen.Raw{sqlNull}} - case bool: - if t == true { - return `1` - } - return `0` - } - - return to.String(val) +// Name returns the name of the table or tables that form the collection. +func (t *table) Name() string { + return strings.Join(t.Tables, `, `) } diff --git a/sqlite/database.go b/sqlite/database.go index 080e95a5e4671aad07b1bff28b6254e981beb186..71f76e0761198ba464bf001e39d91e0a740c1a37 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -24,40 +24,26 @@ package sqlite import ( "database/sql" "fmt" - "os" "strings" "time" - // Importing SQLite3 driver. - _ "github.com/mattn/go-sqlite3" - "upper.io/cache" - "upper.io/db" - "upper.io/db/util/schema" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) -const ( - // Adapter is the public name of the adapter. - Adapter = `sqlite` + "github.com/jmoiron/sqlx" + _ "github.com/mattn/go-sqlite3" // SQLite3 driver. + "upper.io/v2/db" + "upper.io/v2/db/util/schema" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" + "upper.io/v2/db/util/sqlutil/tx" ) var ( - // DateFormat defines the format used for storing dates. - DateFormat = "2006-01-02 15:04:05" - // TimeFormat defines the format used for storing time values. - TimeFormat = "%d:%02d:%02d.%d" -) - -var template *sqlgen.Template - -var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL - session *sql.DB - tx *tx + session *sqlx.DB + tx *sqltx.Tx schema *schema.DatabaseSchema // columns property was introduced so we could query PRAGMA data only once // and retrieve all the column information we'd need, such as name and if it @@ -65,397 +51,392 @@ type source struct { columns map[string][]columnSchemaT } -type columnSchemaT struct { - Name string `db:"name"` - PK int `db:"pk"` +type tx struct { + *sqltx.Tx + *database } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() - } +type columnSchemaT struct { + Name string `db:"name"` + PK int `db:"pk"` } -func init() { - - template = &sqlgen.Template{ - sqlColumnSeparator, - sqlIdentifierSeparator, - sqlIdentifierQuote, - sqlValueSeparator, - sqlValueQuote, - sqlAndKeyword, - sqlOrKeyword, - sqlNotKeyword, - sqlDescKeyword, - sqlAscKeyword, - sqlDefaultOperator, - sqlClauseGroup, - sqlClauseOperator, - sqlColumnValue, - sqlTableAliasLayout, - sqlColumnAliasLayout, - sqlSortByColumnLayout, - sqlWhereLayout, - sqlOrderByLayout, - sqlInsertLayout, - sqlSelectLayout, - sqlUpdateLayout, - sqlDeleteLayout, - sqlTruncateLayout, - sqlDropDatabaseLayout, - sqlDropTableLayout, - sqlSelectCountLayout, - sqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -func (s *source) populateSchema() (err error) { - var collections []string +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error - s.schema = schema.NewDatabaseSchema() + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + Database: settings.Database, + Options: map[string]string{ + "cache": "shared", + }, + } - var conn ConnectionURL + d.connURL = conn + } - if conn, err = ParseURL(s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`sqlite3`, d.connURL.String()); err != nil { return err } - s.schema.Name = conn.Database + d.session.Mapper = sqlutil.NewMapper() - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if err = d.populateSchema(); err != nil { return err } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err - } - } + return nil +} - return err +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { - var query string - var res sql.Result - var err error - var start, end int64 +func (d *database) clone() (*database, error) { + src := &database{} + src.Setup(d.connURL) - start = time.Now().UnixNano() + if err := src.Open(); err != nil { + return nil, err + } - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + return src, nil +} - if s.session == nil { - return nil, db.ErrNotConnected +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() +} + +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() } + return nil +} - query = stmt.Compile(template) +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) - } else { - res, err = s.session.Exec(query, args...) + if len(names) == 0 { + return nil, db.ErrMissingCollectionName } - return res, err -} + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { - var rows *sql.Rows - var query string - var err error - var start, end int64 + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper - start = time.Now().UnixNano() + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } - if s.session == nil { - return nil, db.ErrNotConnected - } + tableName := chunks[0] - query = stmt.Compile(template) + if err := d.tableExists(tableName); err != nil { + return nil, err + } - if s.tx != nil { - rows, err = s.tx.sqlTx.Query(query, args...) - } else { - rows, err = s.session.Query(query, args...) + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } } - return rows, err + return col, nil } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { - var query string - var row *sql.Row - var err error - var start, end int64 - - start = time.Now().UnixNano() +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + tablesInSchema := len(d.schema.Tables) - if s.session == nil { - return nil, db.ErrNotConnected + // Id.schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil } - query = stmt.Compile(template) + // Schema is empty. - if s.tx != nil { - row = s.tx.sqlTx.QueryRow(query, args...) - } else { - row = s.session.QueryRow(query, args...) + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`tbl_name`), + ), + Table: sqlgen.TableWithName(`sqlite_master`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`type`), + Operator: `=`, + Value: sqlgen.NewValue(`table`), + }, + ), } - return row, err -} + // Executing statement. + var rows *sqlx.Rows + if rows, err = d.Query(stmt); err != nil { + return nil, err + } -func (s *source) doRawQuery(query string, args ...interface{}) (*sql.Rows, error) { - var rows *sql.Rows - var err error - var start, end int64 + defer rows.Close() - start = time.Now().UnixNano() + collections = []string{} - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + var name string - if s.session == nil { - return nil, db.ErrNotConnected - } + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err + } - if s.tx != nil { - rows, err = s.tx.sqlTx.Query(query, args...) - } else { - rows, err = s.session.Query(query, args...) + // Adding table entry to schema. + d.schema.AddTable(name) + + // Adding table to collections array. + collections = append(collections, name) } - return rows, err + return collections, nil } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err + } + + conn.Database = database + + d.connURL = conn + + return d.Open() } -func (s *source) clone() (*source, error) { - src := &source{} - src.Setup(s.connURL) +// Drop removes all tables from the current database. +func (d *database) Drop() error { - if err := src.Open(); err != nil { - return nil, err - } + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) - return src, nil + return err } -func (s *source) Clone() (db.Database, error) { - return s.clone() +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() } -func (s *source) Transaction() (db.Tx, error) { +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name +} + +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { var err error - var clone *source - var sqlTx *sql.Tx + var clone *database + var sqlTx *sqlx.Tx - if sqlTx, err = s.session.Begin(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = s.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(conn db.ConnectionURL) error { - s.connURL = conn - return s.Open() -} + clone.tx = sqltx.New(sqlTx) -// Returns the underlying *sql.DB instance. -func (s *source) Driver() interface{} { - return s.session + return tx{Tx: clone.tx, database: clone}, nil } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { + var query string + var res sql.Result var err error + var start, end int64 - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - Database: settings.Database, - Options: map[string]string{ - "cache": "shared", - }, - } + start = time.Now().UnixNano() - s.connURL = conn - } + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - if s.session, err = sql.Open(`sqlite3`, s.connURL.String()); err != nil { - return err + if d.session == nil { + return nil, db.ErrNotConnected } - if err = s.populateSchema(); err != nil { - return err + query = stmt.Compile(template.Template) + + if d.tx != nil { + res, err = d.tx.Exec(query, args...) + } else { + res, err = d.session.Exec(query, args...) } - return nil + return res, err } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows + var query string + var err error + var start, end int64 -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL + start = time.Now().UnixNano() - if conn, err = ParseURL(s.connURL.String()); err != nil { - return err + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() + + if d.session == nil { + return nil, db.ErrNotConnected } - conn.Database = database + query = stmt.Compile(template.Template) - s.connURL = conn + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) + } else { + rows, err = d.session.Queryx(query, args...) + } - return s.Open() + return rows, err } -// Drops the currently active database. -func (s *source) Drop() error { - return db.ErrUnsupported -} +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { + var query string + var row *sqlx.Row + var err error + var start, end int64 -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { + start = time.Now().UnixNano() - tablesInSchema := len(s.schema.Tables) + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil + if d.session == nil { + return nil, db.ErrNotConnected } - // Schema is empty. + query = stmt.Compile(template.Template) - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`tbl_name`}, - }, - Table: sqlgen.Table{`sqlite_master`}, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`type`}, - `=`, - sqlgen.Value{`table`}, - }, - }, + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) + } else { + row = d.session.QueryRowx(query, args...) } - // Executing statement. - var rows *sql.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err - } + return row, err +} - defer rows.Close() +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - collections = []string{} + d.schema = schema.NewDatabaseSchema() - var name string + var conn ConnectionURL - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err - } + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err + } - // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.Name = conn.Database - // Adding table to collections array. - collections = append(collections, name) + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { + return err } - return collections, nil + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err + } + } + + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error - var rows *sql.Rows + var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`sqlite_master`}, - Columns: sqlgen.Columns{ - {`tbl_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`type`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`tbl_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`sqlite_master`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`tbl_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`type`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`tbl_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, `table`, names[i]); err != nil { + if rows, err = d.Query(stmt, `table`, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -469,10 +450,10 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil @@ -480,10 +461,10 @@ func (s *source) tableColumns(tableName string) ([]string, error) { q := fmt.Sprintf(`PRAGMA TABLE_INFO('%s')`, tableName) - rows, err := s.doRawQuery(q) + rows, err := d.doRawQuery(q) - if s.columns == nil { - s.columns = make(map[string][]columnSchemaT) + if d.columns == nil { + d.columns = make(map[string][]columnSchemaT) } columns := []columnSchemaT{} @@ -492,81 +473,64 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.columns[tableName] = columns + d.columns[tableName] = columns - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(s.columns)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(d.columns)) - for i := range s.columns[tableName] { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, s.columns[tableName][i].Name) + for i := range d.columns[tableName] { + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, d.columns[tableName][i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error +func (d *database) getPrimaryKey(tableName string) ([]string, error) { + tableSchema := d.schema.Table(tableName) - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } + d.tableColumns(tableName) - col := &table{ - source: s, - names: names, - } - - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) + maxValue := -1 - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName + for i := range d.columns[tableName] { + if d.columns[tableName][i].PK > 0 && d.columns[tableName][i].PK > maxValue { + maxValue = d.columns[tableName][i].PK } + } - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } + if maxValue > 0 { + tableSchema.PrimaryKey = make([]string, maxValue) - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err + for i := range d.columns[tableName] { + if d.columns[tableName][i].PK > 0 { + tableSchema.PrimaryKey[d.columns[tableName][i].PK-1] = d.columns[tableName][i].Name + } } } - return col, nil + return tableSchema.PrimaryKey, nil } -// getPrimaryKey returns the names of the columns that define the primary key -// of the table. -func (s *source) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) +func (d *database) doRawQuery(query string, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows + var err error + var start, end int64 - s.tableColumns(tableName) + start = time.Now().UnixNano() - maxValue := -1 + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - for i := range s.columns[tableName] { - if s.columns[tableName][i].PK > 0 && s.columns[tableName][i].PK > maxValue { - maxValue = s.columns[tableName][i].PK - } + if d.session == nil { + return nil, db.ErrNotConnected } - if maxValue > 0 { - tableSchema.PrimaryKey = make([]string, maxValue) - - for i := range s.columns[tableName] { - if s.columns[tableName][i].PK > 0 { - tableSchema.PrimaryKey[s.columns[tableName][i].PK-1] = s.columns[tableName][i].Name - } - } + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) + } else { + rows, err = d.session.Queryx(query, args...) } - return tableSchema.PrimaryKey, nil + return rows, err } diff --git a/sqlite/database_test.go b/sqlite/database_test.go index 9de16cc04f11902f9d5ac23690eb4befa722556f..ba54503e0381b4681316c3177201d8a222b607cc 100644 --- a/sqlite/database_test.go +++ b/sqlite/database_test.go @@ -31,6 +31,7 @@ package sqlite import ( "database/sql" "errors" + "fmt" "math/rand" "os" "reflect" @@ -39,43 +40,49 @@ import ( "testing" "time" + "github.com/jmoiron/sqlx" "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util/sqlutil" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlutil" ) const ( - database = `_dumps/gotest.sqlite3.db` + databaseName = `_dumps/gotest.sqlite3.db` +) + +const ( + testTimeZone = "Canada/Eastern" ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, } // Structure for testing conversions and datatypes. type testValuesStruct struct { - Uint uint `field:"_uint"` - Uint8 uint8 `field:"_uint8"` - Uint16 uint16 `field:"_uint16"` - Uint32 uint32 `field:"_uint32"` - Uint64 uint64 `field:"_uint64"` - - Int int `field:"_int"` - Int8 int8 `field:"_int8"` - Int16 int16 `field:"_int16"` - Int32 int32 `field:"_int32"` - Int64 int64 `field:"_int64"` - - Float32 float32 `field:"_float32"` - Float64 float64 `field:"_float64"` - - Bool bool `field:"_bool"` - String string `field:"_string"` - - Date time.Time `field:"_date"` - DateN *time.Time `field:"_nildate"` - DateP *time.Time `field:"_ptrdate"` - Time time.Duration `field:"_time"` + Uint uint `db:"_uint"` + Uint8 uint8 `db:"_uint8"` + Uint16 uint16 `db:"_uint16"` + Uint32 uint32 `db:"_uint32"` + Uint64 uint64 `db:"_uint64"` + + Int int `db:"_int"` + Int8 int8 `db:"_int8"` + Int16 int16 `db:"_int16"` + Int32 int32 `db:"_int32"` + Int64 int64 `db:"_int64"` + + Float32 float32 `db:"_float32"` + Float64 float64 `db:"_float64"` + + Bool bool `db:"_bool"` + String string `db:"_string"` + + Date time.Time `db:"_date"` + DateN *time.Time `db:"_nildate"` + DateP *time.Time `db:"_ptrdate"` + DateD *time.Time `db:"_defaultdate,omitempty"` + Time int64 `db:"_time"` } type artistWithInt64Key struct { @@ -114,7 +121,14 @@ func (item *itemWithKey) SetID(keys map[string]interface{}) error { var testValues testValuesStruct func init() { - t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) + loc, err := time.LoadLocation(testTimeZone) + + if err != nil { + panic(err.Error()) + } + + t := time.Date(2011, 7, 28, 1, 2, 3, 0, loc) + tnz := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) testValues = testValuesStruct{ 1, 1, 1, 1, 1, @@ -124,8 +138,9 @@ func init() { "Hello world!", t, nil, - &t, - time.Second * time.Duration(7331), + &tnz, + nil, + int64(time.Second * time.Duration(7331)), } } @@ -152,7 +167,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, } // Opening database. @@ -533,39 +548,10 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an struct with no tags. - rowStruct := struct { - ID uint64 - Name string - }{} - - res = artist.Find() - - for { - err = res.Next(&rowStruct) - - if err == db.ErrNoMoreRows { - break - } - - if err == nil { - if rowStruct.ID == 0 { - t.Fatalf("Expecting a not null ID.") - } - if rowStruct.Name == "" { - t.Fatalf("Expecting a name.") - } - } else { - t.Fatal(err) - } - } - - res.Close() - // Dumping into a tagged struct. rowStruct2 := struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find() @@ -591,7 +577,7 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an slice of maps. + // Dumping into a slice of maps. allRowsMap := []map[string]interface{}{} res = artist.Find() @@ -609,11 +595,10 @@ func TestResultFetch(t *testing.T) { } } - // Dumping into an slice of structs. - + // Dumping into a slice of structs. allRowsStruct := []struct { - ID uint64 - Name string + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` }{} res = artist.Find() @@ -633,8 +618,8 @@ func TestResultFetch(t *testing.T) { // Dumping into an slice of tagged structs. allRowsStruct2 := []struct { - Value1 uint64 `field:"id"` - Value2 string `field:"name"` + Value1 uint64 `db:"id"` + Value2 string `db:"name"` }{} res = artist.Find() @@ -672,8 +657,8 @@ func TestUpdate(t *testing.T) { // Defining destination struct value := struct { - ID uint64 - Name string + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` }{} // Getting the first artist. @@ -704,7 +689,7 @@ func TestUpdate(t *testing.T) { // Updating set with a struct rowStruct := struct { - Name string + Name string `db:"name"` }{strings.ToLower(value.Name)} if err = res.Update(rowStruct); err != nil { @@ -1036,9 +1021,9 @@ func TestRawRelations(t *testing.T) { func TestRawQuery(t *testing.T) { var sess db.Database - var rows *sql.Rows + var rows *sqlx.Rows var err error - var drv *sql.DB + var drv *sqlx.DB type publicationType struct { ID int64 `db:"id,omitempty"` @@ -1052,9 +1037,9 @@ func TestRawQuery(t *testing.T) { defer sess.Close() - drv = sess.Driver().(*sql.DB) + drv = sess.Driver().(*sqlx.DB) - rows, err = drv.Query(` + rows, err = drv.Queryx(` SELECT p.id, p.title AS publication_title, @@ -1349,10 +1334,38 @@ func TestDataTypes(t *testing.T) { // Trying to dump the subject into an empty structure of the same type. var item testValuesStruct - res.One(&item) + if err = res.One(&item); err != nil { + t.Fatal(err) + } + + if item.DateD == nil { + t.Fatal("Expecting default date to have been set on append") + } + + // Copy the default date (this value is set by the database) + testValues.DateD = item.DateD + + loc, _ := time.LoadLocation(testTimeZone) + item.Date = item.Date.In(loc) + + // TODO: Try to guess this conversion. + if item.DateP.Location() != testValues.DateP.Location() { + v := item.DateP.In(testValues.DateP.Location()) + item.DateP = &v + } // The original value and the test subject must match. - if reflect.DeepEqual(item, testValues) == false { + if !reflect.DeepEqual(item, testValues) { + fmt.Printf("item1: %#v\n", item) + fmt.Printf("test2: %#v\n", testValues) + fmt.Printf("item1: %#v\n", item.Date.String()) + fmt.Printf("test2: %#v\n", testValues.Date.String()) + fmt.Printf("item1: %v\n", item.Date.Location().String()) + fmt.Printf("test2: %v\n", testValues.Date.Location().String()) + fmt.Printf("item1: %#v\n", item.DateP) + fmt.Printf("test2: %#v\n", testValues.DateP) + fmt.Printf("item1: %v\n", item.DateP.Location().String()) + fmt.Printf("test2: %v\n", testValues.DateP.Location().String()) t.Fatalf("Struct is different.") } } @@ -1373,7 +1386,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec(`DELETE FROM "artist"`); err != nil { b.Fatal(err) @@ -1427,7 +1440,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/sqlite/result.go b/sqlite/result.go deleted file mode 100644 index c3a84226905e1dd0551da1d23ad7d27762fbcaf1..0000000000000000000000000000000000000000 --- a/sqlite/result.go +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package sqlite - -import ( - "database/sql" - "fmt" - "strings" - - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"_t"` -} - -type result struct { - table *table - cursor *sql.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) error { - - var err error - - // Current cursor. - err = r.setCursor() - - if err != nil { - r.Close() - } - - // Fetching the next result from the cursor. - err = sqlutil.FetchRow(r.cursor, dst) - - if err != nil { - r.Close() - } - - return err -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values, toInternal) - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() error { - var err error - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counting the elements that will be returned. -func (r *result) Count() (uint64, error) { - var count counter - - rows, err := r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - defer rows.Close() - if err = sqlutil.FetchRow(rows, &count); err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go new file mode 100644 index 0000000000000000000000000000000000000000..8ad04fa82325fbfd599cdf1ddb527fb800e6e0a0 --- /dev/null +++ b/sqlite/sqlite.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sqlite // import "upper.io/v2/db/sqlite" + +import ( + "upper.io/cache" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `sqlite` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/sqlite/layout.go b/sqlite/template.go similarity index 65% rename from sqlite/layout.go rename to sqlite/template.go index 9d937696e41b9247a817b61f8507774850b62424..abea16614305df659ce23b2b33a84df02a01da90 100644 --- a/sqlite/layout.go +++ b/sqlite/template.go @@ -22,37 +22,38 @@ package sqlite const ( - sqlColumnSeparator = `.` - sqlIdentifierSeparator = `, ` - sqlIdentifierQuote = `"{{.Raw}}"` - sqlValueSeparator = `, ` - sqlValueQuote = `'{{.}}'` - sqlAndKeyword = `AND` - sqlOrKeyword = `OR` - sqlNotKeyword = `NOT` - sqlDescKeyword = `DESC` - sqlAscKeyword = `ASC` - sqlDefaultOperator = `=` - sqlClauseGroup = `({{.}})` - sqlClauseOperator = ` {{.}} ` - sqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - sqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - sqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - sqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - sqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - sqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - sqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -75,24 +76,24 @@ const ( {{if .Offset}} {{if not .Limit}} - LIMIT -1 + LIMIT -1 {{end}} OFFSET {{.Offset}} {{end}} ` - sqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - sqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - sqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -104,13 +105,13 @@ const ( {{if .Offset}} {{if not .Limit}} - LIMIT -1 + LIMIT -1 {{end}} OFFSET {{.Offset}} {{end}} ` - sqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -118,23 +119,21 @@ const ( {{.Extra}} ` - sqlTruncateLayout = ` + adapterTruncateLayout = ` DELETE FROM {{.Table}} ` - sqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - sqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - sqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - sqlNull = `NULL` ) diff --git a/sqlite/tx.go b/sqlite/tx.go deleted file mode 100644 index 9f4c622e66315c5bc1bc286dedb67a9584e50a07..0000000000000000000000000000000000000000 --- a/sqlite/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package sqlite - -import ( - "database/sql" -) - -type tx struct { - *source - sqlTx *sql.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/util/main.go b/util/main.go deleted file mode 100644 index fb82ba9914920fdad1f7a5a946e37db41d207bd9..0000000000000000000000000000000000000000 --- a/util/main.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package util - -import ( - "reflect" - "regexp" - "strings" - "time" - - "menteslibres.net/gosexy/to" -) - -var reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`) - -var ( - durationType = reflect.TypeOf(time.Duration(0)) - timeType = reflect.TypeOf(time.Time{}) - ptimeType = reflect.TypeOf(&time.Time{}) -) - -type tagOptions map[string]bool - -func parseTagOptions(s string) tagOptions { - opts := make(tagOptions) - chunks := strings.Split(s, `,`) - for _, chunk := range chunks { - opts[strings.TrimSpace(chunk)] = true - } - return opts -} - -// ParseTag splits a struct tag into comma separated chunks. The first chunk is -// returned as a string value, remaining chunks are considered enabled options. -func ParseTag(tag string) (string, tagOptions) { - // Based on http://golang.org/src/pkg/encoding/json/tags.go - if i := strings.Index(tag, `,`); i != -1 { - return tag[:i], parseTagOptions(tag[i+1:]) - } - return tag, parseTagOptions(``) -} - -// GetStructFieldIndex returns the struct field index for a given column name -// or nil, if no column matches. -func GetStructFieldIndex(t reflect.Type, columnName string) []int { - - n := t.NumField() - - for i := 0; i < n; i++ { - - field := t.Field(i) - - if field.PkgPath != `` { - // Field is unexported. - continue - } - - // Attempt to use db:`column_name` - fieldName, fieldOptions := ParseTag(field.Tag.Get(`db`)) - - // Deprecated `field` tag. - if deprecatedField := field.Tag.Get(`field`); deprecatedField != `` { - fieldName = deprecatedField - } - - // Deprecated `inline` tag. - if deprecatedInline := field.Tag.Get(`inline`); deprecatedInline != `` { - fieldOptions[`inline`] = true - } - - // Skipping field - if fieldName == `-` { - continue - } - - // Trying to match field name. - - // Explicit JSON or BSON options. - if fieldName == `` && fieldOptions[`bson`] { - // Using name from the BSON tag. - fieldName, _ = ParseTag(field.Tag.Get(`bson`)) - } - - if fieldName == `` && fieldOptions[`bson`] { - // Using name from the JSON tag. - fieldName, _ = ParseTag(field.Tag.Get(`bson`)) - } - - // Still don't have a match? try to match againt JSON. - if fieldName == `` { - fieldName, _ = ParseTag(field.Tag.Get(`json`)) - } - - // Still don't have a match? try to match againt BSON. - if fieldName == `` { - fieldName, _ = ParseTag(field.Tag.Get(`bson`)) - } - - // Attempt to match field name. - if fieldName == columnName { - return []int{i} - } - - // Nothing works, trying to match by name. - if fieldName == `` { - if NormalizeColumn(field.Name) == NormalizeColumn(columnName) { - return []int{i} - } - } - - // Inline option. - if fieldOptions[`inline`] == true { - index := GetStructFieldIndex(field.Type, columnName) - if index != nil { - res := append([]int{i}, index...) - return res - } - } - } - // No match. - return nil -} - -// StringToType converts a string value into another type. -func StringToType(src string, dstt reflect.Type) (srcv reflect.Value, err error) { - - // Is destination a pointer? - if dstt.Kind() == reflect.Ptr { - if src == "" { - return - } - } - - switch dstt { - case durationType: - srcv = reflect.ValueOf(to.Duration(src)) - case timeType: - srcv = reflect.ValueOf(to.Time(src)) - case ptimeType: - t := to.Time(src) - srcv = reflect.ValueOf(&t) - default: - return StringToKind(src, dstt.Kind()) - } - return srcv, nil -} - -// StringToKind converts a string into a kind. -func StringToKind(src string, dstk reflect.Kind) (reflect.Value, error) { - var srcv reflect.Value - - // Destination type. - switch dstk { - case reflect.Interface: - // Destination is interface, nuff said. - srcv = reflect.ValueOf(src) - default: - cv, err := to.Convert(src, dstk) - if err != nil { - return srcv, nil - } - srcv = reflect.ValueOf(cv) - } - - return srcv, nil -} - -// NormalizeColumn prepares a column for comparison against another column. -func NormalizeColumn(s string) string { - return strings.ToLower(reColumnCompareExclude.ReplaceAllString(s, "")) -} diff --git a/util/schema/main.go b/util/schema/schema.go similarity index 100% rename from util/schema/main.go rename to util/schema/schema.go diff --git a/util/sqlgen/benchmark_test.go b/util/sqlgen/benchmark_test.go deleted file mode 100644 index 923950352ff8da46fdd6d2026175fcfb36250149..0000000000000000000000000000000000000000 --- a/util/sqlgen/benchmark_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package sqlgen - -import ( - "fmt" - "math/rand" - "testing" -) - -func BenchmarkColumn(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Column{"a"} - } -} - -func BenchmarkCompileColumn(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Column{Value: "a"}.Compile(defaultTemplate) - } -} - -func BenchmarkColumns(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Columns{{"a"}, {"b"}, {"c"}} - } -} - -func BenchmarkCompileColumns(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Columns{{"a"}, {"b"}, {"c"}}.Compile(defaultTemplate) - } -} - -func BenchmarkValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Value{"a"} - } -} - -func BenchmarkCompileValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Value{"a"}.Compile(defaultTemplate) - } -} - -func BenchmarkValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Values{{"a"}, {"b"}, {"c"}, {1}, {2}, {3}} - } -} - -func BenchmarkCompileValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Values{{"a"}, {"b"}, {"c"}, {1}, {2}, {3}}.Compile(defaultTemplate) - } -} - -func BenchmarkDatabase(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Database{"TestDatabase"} - } -} - -func BenchmarkCompileDatabase(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Database{"TestDatabase"}.Compile(defaultTemplate) - } -} - -func BenchmarkValueRaw(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Value{Raw{"a"}} - } -} - -func BenchmarkColumnValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}} - } -} - -func BenchmarkCompileColumnValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}}.Compile(defaultTemplate) - } -} - -func BenchmarkColumnValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValues{{Column{"a"}, "=", Value{Raw{"7"}}}} - } -} - -func BenchmarkCompileColumnValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValues{{Column{"a"}, "=", Value{Raw{"7"}}}}.Compile(defaultTemplate) - } -} - -func BenchmarkOrderBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = OrderBy{ - SortColumns: SortColumns{ - SortColumn{Column: Column{"foo"}}, - }, - } - } -} - -func BenchmarkCompileOrderBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = OrderBy{ - SortColumns: SortColumns{ - SortColumn{Column: Column{"foo"}}, - }, - }.Compile(defaultTemplate) - } -} - -func BenchmarkGroupBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = GroupBy{ - Column{"foo"}, - } - } -} - -func BenchmarkCompileGroupBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = GroupBy{ - Column{"foo"}, - }.Compile(defaultTemplate) - } -} - -func BenchmarkWhere(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - } - } -} - -func BenchmarkCompileWhere(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }.Compile(defaultTemplate) - } -} - -func BenchmarkTable(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Table{"foo"} - } -} - -func BenchmarkCompileTable(b *testing.B) { - var t string - for i := 0; i < b.N; i++ { - t = Table{"foo"}.Compile(defaultTemplate) - if t != `"foo"` { - b.Fatal("Caching failed.") - } - } -} - -func BenchmarkCompileRandomTable(b *testing.B) { - var t string - var m, n int - var s, e string - - for i := 0; i < b.N; i++ { - m, n = rand.Int(), rand.Int() - s = fmt.Sprintf(`%s as %s`, m, n) - e = fmt.Sprintf(`"%s" AS "%s"`, m, n) - - t = Table{s}.Compile(defaultTemplate) - if t != e { - b.Fatal() - } - } -} - -func BenchmarkCompileSelect(b *testing.B) { - var stmt Statement - - for i := 0; i < b.N; i++ { - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}}, - }, - } - _ = stmt.Compile(defaultTemplate) - } -} diff --git a/util/sqlgen/column.go b/util/sqlgen/column.go index 6a69c98de0ec65b3126cbba9d49754f78c5f573d..4eb395e173a4d2634bbe0c2519508c07f891bfed 100644 --- a/util/sqlgen/column.go +++ b/util/sqlgen/column.go @@ -5,41 +5,58 @@ import ( "strings" ) -type column_t struct { +type columnT struct { Name string Alias string } +// Column represents a SQL column. type Column struct { - Value interface{} + Name interface{} + hash string } -func (self Column) Hash() string { - switch t := self.Value.(type) { - case cc: - return `Column(` + t.Hash() + `)` - case string: - return `Column(` + t + `)` +// ColumnWithName creates and returns a Column with the given name. +func ColumnWithName(name string) *Column { + return &Column{Name: name} +} + +// Hash returns a unique identifier. +func (c *Column) Hash() string { + if c.hash == "" { + var s string + + switch t := c.Name.(type) { + case Fragment: + s = t.Hash() + case fmt.Stringer: + s = t.String() + case string: + s = t + default: + s = fmt.Sprintf("%v", c.Name) + } + + c.hash = fmt.Sprintf(`Column{Name:%q}`, s) } - return fmt.Sprintf(`Column(%v)`, self.Value) + + return c.hash } -func (self Column) Compile(layout *Template) (compiled string) { +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *Column) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - switch value := self.Value.(type) { + switch value := c.Name.(type) { case string: - // input := strings.TrimSpace(value) input := trimString(value) - //chunks := reAliasSeparator.Split(input, 2) chunks := separateByAS(input) if len(chunks) == 1 { - //chunks = reSpaceSeparator.Split(input, 2) chunks = separateBySpace(input) } @@ -48,9 +65,8 @@ func (self Column) Compile(layout *Template) (compiled string) { nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2) for i := range nameChunks { - // nameChunks[i] = strings.TrimSpace(nameChunks[i]) nameChunks[i] = trimString(nameChunks[i]) - nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{nameChunks[i]}) + nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) } name = strings.Join(nameChunks, layout.ColumnSeparator) @@ -58,19 +74,18 @@ func (self Column) Compile(layout *Template) (compiled string) { var alias string if len(chunks) > 1 { - // alias = strings.TrimSpace(chunks[1]) alias = trimString(chunks[1]) - alias = mustParse(layout.IdentifierQuote, Raw{alias}) + alias = mustParse(layout.IdentifierQuote, Raw{Value: alias}) } - compiled = mustParse(layout.ColumnAliasLayout, column_t{name, alias}) + compiled = mustParse(layout.ColumnAliasLayout, columnT{name, alias}) case Raw: compiled = value.String() default: - compiled = fmt.Sprintf("%v", self.Value) + compiled = fmt.Sprintf("%v", c.Name) } - layout.Write(self, compiled) + layout.Write(c, compiled) return } diff --git a/util/sqlgen/column_test.go b/util/sqlgen/column_test.go index 62f3929def3ffcc912e5312bab6affcc4ad5d1c3..e5fdb7d825644a8e76b249da7e65fe20d225b336 100644 --- a/util/sqlgen/column_test.go +++ b/util/sqlgen/column_test.go @@ -1,13 +1,27 @@ package sqlgen import ( + "fmt" "testing" ) +func TestColumnHash(t *testing.T) { + var s, e string + + column := Column{Name: "role.name"} + + s = column.Hash() + e = fmt.Sprintf(`Column{Name:"%s"}`, column.Name) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestColumnString(t *testing.T) { var s, e string - column := Column{"role.name"} + column := Column{Name: "role.name"} s = column.Compile(defaultTemplate) e = `"role"."name"` @@ -20,7 +34,7 @@ func TestColumnString(t *testing.T) { func TestColumnAs(t *testing.T) { var s, e string - column := Column{"role.name as foo"} + column := Column{Name: "role.name as foo"} s = column.Compile(defaultTemplate) e = `"role"."name" AS "foo"` @@ -33,7 +47,7 @@ func TestColumnAs(t *testing.T) { func TestColumnImplicitAs(t *testing.T) { var s, e string - column := Column{"role.name foo"} + column := Column{Name: "role.name foo"} s = column.Compile(defaultTemplate) e = `"role"."name" AS "foo"` @@ -46,7 +60,7 @@ func TestColumnImplicitAs(t *testing.T) { func TestColumnRaw(t *testing.T) { var s, e string - column := Column{Raw{"role.name As foo"}} + column := Column{Name: Raw{Value: "role.name As foo"}} s = column.Compile(defaultTemplate) e = `role.name As foo` @@ -55,3 +69,51 @@ func TestColumnRaw(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkColumnWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ColumnWithName("a") + } +} + +func BenchmarkColumnHash(b *testing.B) { + c := Column{Name: "name"} + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnCompile(b *testing.B) { + c := Column{Name: "name"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Column{Name: "name"} + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithDotCompile(b *testing.B) { + c := Column{Name: "role.name"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name foo"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name AS foo"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/column_value.go b/util/sqlgen/column_value.go index fbe57987c43aec83b0ed24b42cb15156dd82c3f9..485a46f35d09dfd3efd7132d2c03cf091a421d88 100644 --- a/util/sqlgen/column_value.go +++ b/util/sqlgen/column_value.go @@ -1,71 +1,93 @@ package sqlgen import ( + "fmt" "strings" ) +// ColumnValue represents a bundle between a column and a corresponding value. type ColumnValue struct { - Column + Column Fragment Operator string - Value + Value Fragment + hash string } -type columnValue_s struct { +type columnValueT struct { Column string Operator string Value string } -func (self ColumnValue) Hash() string { - return `ColumnValue(` + self.Column.Hash() + `;` + self.Operator + `;` + self.Value.Hash() + `)` +// Hash returns a unique identifier. +func (c *ColumnValue) Hash() string { + if c.hash == "" { + c.hash = fmt.Sprintf(`ColumnValue{Name:%q, Operator:%q, Value:%q}`, c.Column.Hash(), c.Operator, c.Value.Hash()) + } + return c.hash } -func (self ColumnValue) Compile(layout *Template) (compiled string) { +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *ColumnValue) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - data := columnValue_s{ - self.Column.Compile(layout), - self.Operator, - self.Value.Compile(layout), + data := columnValueT{ + c.Column.Compile(layout), + c.Operator, + c.Value.Compile(layout), } compiled = mustParse(layout.ColumnValue, data) - layout.Write(self, compiled) + layout.Write(c, compiled) return } -type ColumnValues []ColumnValue +// ColumnValues represents an array of ColumnValue +type ColumnValues struct { + ColumnValues []Fragment + hash string +} + +// JoinColumnValues returns an array of ColumnValue +func JoinColumnValues(values ...Fragment) *ColumnValues { + return &ColumnValues{ColumnValues: values} +} -func (self ColumnValues) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +// Hash returns a unique identifier. +func (c *ColumnValues) Hash() string { + if c.hash == "" { + s := make([]string, len(c.ColumnValues)) + for i := range c.ColumnValues { + s[i] = c.ColumnValues[i].Hash() + } + c.hash = fmt.Sprintf("ColumnValues{ColumnValues:{%s}}", strings.Join(s, ", ")) } - return `ColumnValues(` + strings.Join(hash, `,`) + `)` + return c.hash } -func (self ColumnValues) Compile(layout *Template) (compiled string) { +// Compile transforms the ColumnValues into its SQL representation. +func (c *ColumnValues) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - l := len(self) + l := len(c.ColumnValues) out := make([]string, l) - for i := 0; i < l; i++ { - out[i] = self[i].Compile(layout) + for i := range c.ColumnValues { + out[i] = c.ColumnValues[i].Compile(layout) } compiled = strings.Join(out, layout.IdentifierSeparator) - layout.Write(self, compiled) + layout.Write(c, compiled) return } diff --git a/util/sqlgen/column_value_test.go b/util/sqlgen/column_value_test.go index b535ad1c245c928f97a9060746d9199b60d96253..9a954697805520941033a3366af7a605b979bb7a 100644 --- a/util/sqlgen/column_value_test.go +++ b/util/sqlgen/column_value_test.go @@ -1,14 +1,45 @@ package sqlgen import ( + "fmt" "testing" ) +func TestColumnValueHash(t *testing.T) { + var s, e string + + c := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + + s = c.Hash() + e = fmt.Sprintf(`ColumnValue{Name:%q, Operator:%q, Value:%q}`, c.Column.Hash(), c.Operator, c.Value.Hash()) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestColumnValuesHash(t *testing.T) { + var s, e string + + c := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)}, + &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(2)}, + ) + + s = c.Hash() + + e = fmt.Sprintf(`ColumnValues{ColumnValues:{%s, %s}}`, c.ColumnValues[0].Hash(), c.ColumnValues[1].Hash()) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestColumnValue(t *testing.T) { var s, e string - var cv ColumnValue + var cv *ColumnValue - cv = ColumnValue{Column{"id"}, "=", Value{1}} + cv = &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} s = cv.Compile(defaultTemplate) e = `"id" = '1'` @@ -17,7 +48,7 @@ func TestColumnValue(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } - cv = ColumnValue{Column{"date"}, "=", Value{Raw{"NOW()"}}} + cv = &ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: NewValue(RawValue("NOW()"))} s = cv.Compile(defaultTemplate) e = `"date" = NOW()` @@ -29,15 +60,14 @@ func TestColumnValue(t *testing.T) { func TestColumnValues(t *testing.T) { var s, e string - var cvs ColumnValues - cvs = ColumnValues{ - {Column{"id"}, ">", Value{8}}, - {Column{"other.id"}, "<", Value{Raw{"100"}}}, - {Column{"name"}, "=", Value{"Haruki Murakami"}}, - {Column{"created"}, ">=", Value{Raw{"NOW()"}}}, - {Column{"modified"}, "<=", Value{Raw{"NOW()"}}}, - } + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) s = cvs.Compile(defaultTemplate) e = `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()` @@ -45,5 +75,82 @@ func TestColumnValues(t *testing.T) { if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) } +} +func BenchmarkNewColumnValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = &ColumnValue{Column: ColumnWithName("a"), Operator: "=", Value: NewValue(Raw{Value: "7"})} + } +} + +func BenchmarkColumnValueHash(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + for i := 0; i < b.N; i++ { + cv.Hash() + } +} + +func BenchmarkColumnValueCompile(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + for i := 0; i < b.N; i++ { + cv.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + cv.Compile(defaultTemplate) + } +} + +func BenchmarkJoinColumnValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + } +} + +func BenchmarkColumnValuesHash(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + for i := 0; i < b.N; i++ { + cvs.Hash() + } +} + +func BenchmarkColumnValuesCompile(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + for i := 0; i < b.N; i++ { + cvs.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValuesCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + cvs.Compile(defaultTemplate) + } } diff --git a/util/sqlgen/columns.go b/util/sqlgen/columns.go index dfe41f2e38c5133550751bd4d1a3501081e1537c..507cb215f55491a1df1e3c1aa84caea16305e7e0 100644 --- a/util/sqlgen/columns.go +++ b/util/sqlgen/columns.go @@ -1,38 +1,53 @@ package sqlgen import ( + "fmt" "strings" ) -type Columns []Column +// Columns represents an array of Column. +type Columns struct { + Columns []Fragment + hash string +} -func (self Columns) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +// Hash returns a unique identifier. +func (c *Columns) Hash() string { + if c.hash == "" { + s := make([]string, len(c.Columns)) + for i := range c.Columns { + s[i] = c.Columns[i].Hash() + } + c.hash = fmt.Sprintf("Columns{Columns:{%s}}", strings.Join(s, ", ")) } - return `Columns(` + strings.Join(hash, `,`) + `)` + return c.hash +} + +// JoinColumns creates and returns an array of Column. +func JoinColumns(columns ...Fragment) *Columns { + return &Columns{Columns: columns} } -func (self Columns) Compile(layout *Template) (compiled string) { +// Compile transforms the Columns into an equivalent SQL representation. +func (c *Columns) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - l := len(self) + l := len(c.Columns) if l > 0 { out := make([]string, l) for i := 0; i < l; i++ { - out[i] = self[i].Compile(layout) + out[i] = c.Columns[i].Compile(layout) } compiled = strings.Join(out, layout.IdentifierSeparator) } - layout.Write(self, compiled) + layout.Write(c, compiled) return } diff --git a/util/sqlgen/columns_test.go b/util/sqlgen/columns_test.go index 668c0c420c4b72b8c46ab0919f4500ac299dde24..a4f439799f33c766a63ce971364e9d65e8cf1339 100644 --- a/util/sqlgen/columns_test.go +++ b/util/sqlgen/columns_test.go @@ -7,13 +7,13 @@ import ( func TestColumns(t *testing.T) { var s, e string - columns := Columns{ - {"id"}, - {"customer"}, - {"service_id"}, - {"role.name"}, - {"role.id"}, - } + columns := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) s = columns.Compile(defaultTemplate) e = `"id", "customer", "service_id", "role"."name", "role"."id"` @@ -21,5 +21,53 @@ func TestColumns(t *testing.T) { if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) } +} + +func BenchmarkJoinColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkColumnsHash(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnsCompile(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} +func BenchmarkColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + c.Compile(defaultTemplate) + } } diff --git a/util/sqlgen/database.go b/util/sqlgen/database.go index 6c5f731477d0053b264ff9640f03a2288c1e8dbd..df7001dd80c698529ca93371df4186f544ab07c4 100644 --- a/util/sqlgen/database.go +++ b/util/sqlgen/database.go @@ -4,22 +4,34 @@ import ( "fmt" ) +// Database represents a SQL database. type Database struct { - Value string + Name string + hash string } -func (self Database) Hash() string { - return `Database(` + self.Value + `)` +// DatabaseWithName returns a Database with the given name. +func DatabaseWithName(name string) *Database { + return &Database{Name: name} } -func (self Database) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { +// Hash returns a unique identifier. +func (d *Database) Hash() string { + if d.hash == "" { + d.hash = fmt.Sprintf(`Database{Name:%q}`, d.Name) + } + return d.hash +} + +// Compile transforms the Database into an equivalent SQL representation. +func (d *Database) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(d); ok { return c } - compiled = mustParse(layout.IdentifierQuote, Raw{fmt.Sprintf(`%v`, self.Value)}) + compiled = mustParse(layout.IdentifierQuote, Raw{Value: d.Name}) - layout.Write(self, compiled) + layout.Write(d, compiled) return } diff --git a/util/sqlgen/database_test.go b/util/sqlgen/database_test.go new file mode 100644 index 0000000000000000000000000000000000000000..33b1ad8212bcddb7af7618047a3972a27178cada --- /dev/null +++ b/util/sqlgen/database_test.go @@ -0,0 +1,53 @@ +package sqlgen + +import ( + "fmt" + "testing" +) + +func TestDatabaseHash(t *testing.T) { + var s, e string + + column := Database{Name: "users"} + + s = column.Hash() + e = fmt.Sprintf(`Database{Name:"%s"}`, column.Name) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDatabaseCompile(t *testing.T) { + var s, e string + + column := Database{Name: "name"} + + s = column.Compile(defaultTemplate) + e = `"name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkDatabaseHash(b *testing.B) { + c := Database{Name: "name"} + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkDatabaseCompile(b *testing.B) { + c := Database{Name: "name"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkDatabaseCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Database{Name: "name"} + c.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/default.go b/util/sqlgen/default.go index 4b2a3f20922f845b401e3af2416e891b2da8a57b..750dc9b3d4ac77f68e2da1bb1a5477875123af06 100644 --- a/util/sqlgen/default.go +++ b/util/sqlgen/default.go @@ -7,7 +7,7 @@ import ( const ( defaultColumnSeparator = `.` defaultIdentifierSeparator = `, ` - defaultIdentifierQuote = `"{{.Raw}}"` + defaultIdentifierQuote = `"{{.Value}}"` defaultValueSeparator = `, ` defaultValueQuote = `'{{.}}'` defaultAndKeyword = `AND` @@ -21,7 +21,7 @@ const ( defaultColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` defaultTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` defaultColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - defaultSortByColumnLayout = `{{.Column}} {{.Sort}}` + defaultSortByColumnLayout = `{{.Column}} {{.Order}}` defaultOrderByLayout = ` {{if .SortColumns}} @@ -72,7 +72,7 @@ const ( {{ .Where }} ` - defaultSelectCountLayout = ` + defaultCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -143,7 +143,7 @@ var defaultTemplate = &Template{ TruncateLayout: defaultTruncateLayout, DropDatabaseLayout: defaultDropDatabaseLayout, DropTableLayout: defaultDropTableLayout, - SelectCountLayout: defaultSelectCountLayout, + CountLayout: defaultCountLayout, GroupByLayout: defaultGroupByLayout, Cache: cache.NewCache(), } diff --git a/util/sqlgen/group_by.go b/util/sqlgen/group_by.go index 28aa812fce9d25ace9f726c627530568aa83a9a2..fe8ed3f34c7d0b05e343432a8561de96a3ae9656 100644 --- a/util/sqlgen/group_by.go +++ b/util/sqlgen/group_by.go @@ -1,31 +1,50 @@ package sqlgen -type GroupBy Columns +import ( + "fmt" +) + +// GroupBy represents a SQL's "group by" statement. +type GroupBy struct { + Columns Fragment + hash string +} -type groupBy_s struct { +type groupByT struct { GroupColumns string } -func (self GroupBy) Hash() string { - return `GroupBy(` + Columns(self).Hash() + `)` +// Hash returns a unique identifier. +func (g *GroupBy) Hash() string { + if g.hash == "" { + if g.Columns != nil { + g.hash = fmt.Sprintf(`GroupBy(%s)`, g.Columns.Hash()) + } + } + return g.hash +} + +// GroupByColumns creates and returns a GroupBy with the given column. +func GroupByColumns(columns ...Fragment) *GroupBy { + return &GroupBy{Columns: JoinColumns(columns...)} } -func (self GroupBy) Compile(layout *Template) (compiled string) { +// Compile transforms the GroupBy into an equivalent SQL representation. +func (g *GroupBy) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { + if c, ok := layout.Read(g); ok { return c } - if len(self) > 0 { - - data := groupBy_s{ - GroupColumns: Columns(self).Compile(layout), + if g.Columns != nil { + data := groupByT{ + GroupColumns: g.Columns.Compile(layout), } compiled = mustParse(layout.GroupByLayout, data) } - layout.Write(self, compiled) + layout.Write(g, compiled) return } diff --git a/util/sqlgen/group_by_test.go b/util/sqlgen/group_by_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c6c6a6f3e34d7bb4c99ac35e3e020e5f1506a389 --- /dev/null +++ b/util/sqlgen/group_by_test.go @@ -0,0 +1,73 @@ +package sqlgen + +import ( + "testing" +) + +func TestGroupBy(t *testing.T) { + var s, e string + + columns := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + + s = columns.Compile(defaultTemplate) + e = `GROUP BY "id", "customer", "service_id", "role"."name", "role"."id"` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkGroupByColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = GroupByColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkGroupByHash(b *testing.B) { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkGroupByCompile(b *testing.B) { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkGroupByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + c.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/interfaces.go b/util/sqlgen/interfaces.go index 234742c40098cf1696bbe100b25faafa6070482a..8d6cb109034d5890bb47b8aeb25395640e7b9398 100644 --- a/util/sqlgen/interfaces.go +++ b/util/sqlgen/interfaces.go @@ -4,8 +4,8 @@ import ( "upper.io/cache" ) -type cc interface { - cache.Cacheable +type Fragment interface { + cache.Hashable compilable } diff --git a/util/sqlgen/main.go b/util/sqlgen/main.go deleted file mode 100644 index 098c113ced577d2df119eded5cd585b42ea513ce..0000000000000000000000000000000000000000 --- a/util/sqlgen/main.go +++ /dev/null @@ -1,44 +0,0 @@ -package sqlgen - -import ( - "bytes" - "text/template" -) - -type Type uint - -const ( - SqlTruncate = iota - SqlDropTable - SqlDropDatabase - SqlSelectCount - SqlInsert - SqlSelect - SqlUpdate - SqlDelete -) - -type ( - Limit int - Offset int - Extra string -) - -var parsedTemplates = make(map[string]*template.Template) - -func mustParse(text string, data interface{}) (compiled string) { - var b bytes.Buffer - var ok bool - - if _, ok = parsedTemplates[text]; ok == false { - parsedTemplates[text] = template.Must(template.New("").Parse(text)) - } - - if err := parsedTemplates[text].Execute(&b, data); err != nil { - panic("There was an error compiling the following template:\n" + text + "\nError was: " + err.Error()) - } - - compiled = b.String() - - return -} diff --git a/util/sqlgen/main_test.go b/util/sqlgen/main_test.go deleted file mode 100644 index c30c851c4103e52b76fef1f8da5d5d4e56e199d2..0000000000000000000000000000000000000000 --- a/util/sqlgen/main_test.go +++ /dev/null @@ -1,662 +0,0 @@ -package sqlgen - -import ( - "testing" -) - -func TestTruncateTable(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlTruncate, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `TRUNCATE TABLE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestDropTable(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlDropTable, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `DROP TABLE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestDropDatabase(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlDropDatabase, - Database: Database{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `DROP DATABASE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectCount(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectCountRelation(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"information_schema.tables"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "information_schema"."tables"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectCountWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFromAlias(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table.name AS foo"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFromRawWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table.name AS foo"}, - Where: Where{ - Raw{"foo.id = bar.foo_id"}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table.name AS foo"}, - Where: Where{ - Raw{"foo.id = bar.foo_id"}, - Raw{"baz.id = exp.baz_id"}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFromMany(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"first.table AS foo, second.table as BAR, third.table aS baz"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectArtistNameFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"artist"}, - Columns: Columns{ - {"artist.name"}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "artist"."name" FROM "artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectRawFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{`artist`}, - Columns: Columns{ - {`artist.name`}, - {Raw{`CONCAT(artist.name, " ", artist.last_name)`}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWithLimitOffset(t *testing.T) { - var s, e string - var stmt Statement - - // LIMIT only. - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Limit: 42, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // OFFSET only. - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Offset: 17, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // LIMIT AND OFFSET. - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Limit: 42, - Offset: 17, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestGroupBy(t *testing.T) { - var s, e string - var stmt Statement - - // Simple GROUP BY - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - GroupBy: GroupBy{ - Column{"foo"}, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - GroupBy: GroupBy{ - Column{"foo"}, - Column{"bar"}, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWithOrderBy(t *testing.T) { - var s, e string - var stmt Statement - - // Simple ORDER BY - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns: SortColumns{ - SortColumn{Column: Column{"foo"}}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY field ASC - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - SortColumn{Column{"foo"}, SqlSortAsc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY field DESC - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - {Column{"foo"}, SqlSortDesc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY many fields - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - {Column{"foo"}, SqlSortDesc}, - {Column{"bar"}, SqlSortAsc}, - {Column{"baz"}, SqlSortDesc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY function - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - {Column{Raw{"FOO()"}}, SqlSortDesc}, - {Column{Raw{"BAR()"}}, SqlSortAsc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - Limit: 10, - Offset: 23, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestDelete(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlDelete, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `DELETE FROM "table_name" WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestUpdate(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlUpdate, - Table: Table{"table_name"}, - ColumnValues: ColumnValues{ - {Column{"foo"}, "=", Value{76}}, - }, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - stmt = Statement{ - Type: SqlUpdate, - Table: Table{"table_name"}, - ColumnValues: ColumnValues{ - {Column{"foo"}, "=", Value{76}}, - {Column{"bar"}, "=", Value{Raw{"88"}}}, - }, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestInsert(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlInsert, - Table: Table{"table_name"}, - Columns: Columns{ - Column{"foo"}, - Column{"bar"}, - Column{"baz"}, - }, - Values: Values{ - Value{"1"}, - Value{2}, - Value{Raw{"3"}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestInsertExtra(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlInsert, - Table: Table{"table_name"}, - Extra: "RETURNING id", - Columns: Columns{ - Column{"foo"}, - Column{"bar"}, - Column{"baz"}, - }, - Values: Values{ - Value{"1"}, - Value{2}, - Value{Raw{"3"}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING id` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} diff --git a/util/sqlgen/order_by.go b/util/sqlgen/order_by.go index a88fe69a3cb5e7077ece8af0e4f49d94d1b5c86a..d437cde386ba7df348deff2f302b73e088f378fe 100644 --- a/util/sqlgen/order_by.go +++ b/util/sqlgen/order_by.go @@ -1,112 +1,163 @@ package sqlgen import ( + "fmt" "strings" ) +// Order represents the order in which SQL results are sorted. +type Order uint8 + +// Possible values for Order +const ( + DefaultOrder = Order(iota) + Ascendent + Descendent +) + +// SortColumn represents the column-order relation in an ORDER BY clause. type SortColumn struct { - Column - Sort + Column Fragment + Order + hash string } -type sortColumn_s struct { +type sortColumnT struct { Column string - Sort string + Order string } -type SortColumns []SortColumn +// SortColumns represents the columns in an ORDER BY clause. +type SortColumns struct { + Columns []Fragment + hash string +} -func (self SortColumn) Hash() string { - return `SortColumn(` + self.Column.Hash() + `;` + self.Sort.Hash() + `)` +// OrderBy represents an ORDER BY clause. +type OrderBy struct { + SortColumns Fragment + hash string } -func (self SortColumns) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `SortColumns(` + strings.Join(hash, `,`) + `)` +type orderByT struct { + SortColumns string } -func (self SortColumns) Compile(layout *Template) string { - l := len(self) - s := make([]string, 0, l) - for i := 0; i < l; i++ { - s = append(s, self[i].Compile(layout)) +// JoinSortColumns creates and returns an array of column-order relations. +func JoinSortColumns(values ...Fragment) *SortColumns { + return &SortColumns{Columns: values} +} + +// JoinWithOrderBy creates an returns an OrderBy using the given SortColumns. +func JoinWithOrderBy(sc *SortColumns) *OrderBy { + return &OrderBy{SortColumns: sc} +} + +// Hash returns a unique identifier. +func (s *SortColumn) Hash() string { + if s.hash == "" { + s.hash = fmt.Sprintf(`SortColumn{Column:%s, Order:%s}`, s.Column.Hash(), s.Order.Hash()) } - return strings.Join(s, layout.IdentifierSeparator) + return s.hash } -func (self SortColumn) Compile(layout *Template) (compiled string) { +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s *SortColumn) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { + if c, ok := layout.Read(s); ok { return c } - data := sortColumn_s{ - Column: self.Column.Compile(layout), - Sort: self.Sort.Compile(layout), + data := sortColumnT{ + Column: s.Column.Compile(layout), + Order: s.Order.Compile(layout), } compiled = mustParse(layout.SortByColumnLayout, data) - layout.Write(self, compiled) + layout.Write(s, compiled) + return } -type OrderBy struct { - SortColumns +// Hash returns a unique identifier. +func (s *SortColumns) Hash() string { + if s.hash == "" { + h := make([]string, len(s.Columns)) + for i := range s.Columns { + h[i] = s.Columns[i].Hash() + } + s.hash = fmt.Sprintf(`SortColumns(%s)`, strings.Join(h, `, `)) + } + return s.hash } -type orderBy_s struct { - SortColumns string +// Compile transforms the SortColumns into an equivalent SQL representation. +func (s *SortColumns) Compile(layout *Template) (compiled string) { + + if z, ok := layout.Read(s); ok { + return z + } + + z := make([]string, len(s.Columns)) + + for i := range s.Columns { + z[i] = s.Columns[i].Compile(layout) + } + + compiled = strings.Join(z, layout.IdentifierSeparator) + + layout.Write(s, compiled) + + return } -func (self OrderBy) Hash() string { - return `OrderBy(` + self.SortColumns.Hash() + `)` +// Hash returns a unique identifier. +func (s *OrderBy) Hash() string { + if s.hash == "" { + if s.SortColumns != nil { + s.hash = `OrderBy(` + s.SortColumns.Hash() + `)` + } + } + return s.hash } -func (self OrderBy) Compile(layout *Template) (compiled string) { +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s *OrderBy) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(s); ok { + return z } - if len(self.SortColumns) > 0 { - data := orderBy_s{ - SortColumns: self.SortColumns.Compile(layout), + if s.SortColumns != nil { + data := orderByT{ + SortColumns: s.SortColumns.Compile(layout), } compiled = mustParse(layout.OrderByLayout, data) } - layout.Write(self, compiled) + layout.Write(s, compiled) return } -type Sort uint8 - -const ( - SqlSortNone = iota - SqlSortAsc - SqlSortDesc -) - -func (self Sort) Hash() string { - switch self { - case SqlSortAsc: - return `Sort(1)` - case SqlSortDesc: - return `Sort(2)` +// Hash returns a unique identifier. +func (s Order) Hash() string { + switch s { + case Ascendent: + return `Order{ASC}` + case Descendent: + return `Order{DESC}` } - return `Sort(0)` + return `Order{DEFAULT}` } -func (self Sort) Compile(layout *Template) string { - switch self { - case SqlSortAsc: +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s Order) Compile(layout *Template) string { + switch s { + case Ascendent: return layout.AscKeyword - case SqlSortDesc: + case Descendent: return layout.DescKeyword } return "" diff --git a/util/sqlgen/order_by_test.go b/util/sqlgen/order_by_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bbb7ac8421c99406c7d54765a3db5137c2488ffa --- /dev/null +++ b/util/sqlgen/order_by_test.go @@ -0,0 +1,143 @@ +package sqlgen + +import ( + "testing" +) + +func TestOrderBy(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + + s := o.Compile(defaultTemplate) + e := `ORDER BY "foo"` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestOrderByDesc(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + ), + ) + + s := o.Compile(defaultTemplate) + e := `ORDER BY "foo" DESC` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkOrderBy(b *testing.B) { + for i := 0; i < b.N; i++ { + JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + } +} + +func BenchmarkOrderByHash(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + for i := 0; i < b.N; i++ { + o.Hash() + } +} + +func BenchmarkCompileOrderByCompile(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + for i := 0; i < b.N; i++ { + o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompile(b *testing.B) { + o := Descendent + for i := 0; i < b.N; i++ { + o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := Descendent + o.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnHash(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnCompile(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + for i := 0; i < b.N; i++ { + s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := &SortColumn{Column: &Column{Name: "foo"}} + s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsHash(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnsCompile(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + for i := 0; i < b.N; i++ { + s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + s.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/raw.go b/util/sqlgen/raw.go index cda0e66bd25c0805badfdf648b7f5b7ab6481fcf..ca16c26c1f767c1d06b4cdca12ca83f0f94adf1f 100644 --- a/util/sqlgen/raw.go +++ b/util/sqlgen/raw.go @@ -1,17 +1,38 @@ package sqlgen +import ( + "fmt" +) + +var ( + _ = fmt.Stringer(&Raw{}) +) + +// Raw represents a value that is meant to be used in a query without escaping. type Raw struct { - Raw string + Value string // Value should not be modified after assigned. + hash string +} + +// RawValue creates and returns a new raw value. +func RawValue(v string) *Raw { + return &Raw{Value: v} } -func (self Raw) Hash() string { - return `Raw(` + self.Raw + `)` +// Hash returns a unique identifier. +func (r *Raw) Hash() string { + if r.hash == "" { + r.hash = `Raw{Value:"` + r.Value + `"}` + } + return r.hash } -func (self Raw) Compile(*Template) string { - return self.Raw +// Compile returns the raw value. +func (r *Raw) Compile(*Template) string { + return r.Value } -func (self Raw) String() string { - return self.Raw +// String returns the raw value. +func (r *Raw) String() string { + return r.Value } diff --git a/util/sqlgen/raw_test.go b/util/sqlgen/raw_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9ad57ad6c04a37761ddddc1532583b59a579cb63 --- /dev/null +++ b/util/sqlgen/raw_test.go @@ -0,0 +1,72 @@ +package sqlgen + +import ( + "fmt" + "testing" +) + +func TestRawString(t *testing.T) { + var s, e string + + raw := &Raw{Value: "foo"} + + s = raw.Compile(defaultTemplate) + e = `foo` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestRawCompile(t *testing.T) { + var s, e string + + raw := &Raw{Value: "foo"} + + s = raw.Compile(defaultTemplate) + e = `foo` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestRawHash(t *testing.T) { + var s, e string + + raw := &Raw{Value: "foo"} + + s = raw.Hash() + e = fmt.Sprintf(`Raw{Value:"%s"}`, raw) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkRawCreate(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Raw{Value: "foo"} + } +} + +func BenchmarkRawString(b *testing.B) { + raw := &Raw{Value: "foo"} + for i := 0; i < b.N; i++ { + raw.String() + } +} + +func BenchmarkRawCompile(b *testing.B) { + raw := &Raw{Value: "foo"} + for i := 0; i < b.N; i++ { + raw.Compile(defaultTemplate) + } +} + +func BenchmarkRawHash(b *testing.B) { + raw := &Raw{Value: "foo"} + for i := 0; i < b.N; i++ { + raw.Hash() + } +} diff --git a/util/sqlgen/statement.go b/util/sqlgen/statement.go index b7b4ec472bcc2214f768b63cf854437307c8c002..16b8800d72a2805b6a29986cac5de4f7ceb9c2af 100644 --- a/util/sqlgen/statement.go +++ b/util/sqlgen/statement.go @@ -2,24 +2,29 @@ package sqlgen import ( "strconv" + "strings" + + "upper.io/cache" ) +// Statement represents different kinds of SQL statements. type Statement struct { Type - Table - Database + Table Fragment + Database Fragment Limit Offset - Columns - Values - ColumnValues - OrderBy - GroupBy + Columns Fragment + Values Fragment + ColumnValues Fragment + OrderBy Fragment + GroupBy Fragment Extra - Where + Where Fragment + hash string } -type statement_s struct { +type statementT struct { Table string Database string Limit @@ -33,64 +38,86 @@ type statement_s struct { Where string } -func (self Statement) Hash() string { - hash := `Statement(` + - strconv.Itoa(int(self.Type)) + `;` + - self.Table.Hash() + `;` + - self.Database.Hash() + `;` + - strconv.Itoa(int(self.Limit)) + `;` + - strconv.Itoa(int(self.Offset)) + `;` + - self.Columns.Hash() + `;` + - self.Values.Hash() + `;` + - self.ColumnValues.Hash() + `;` + - self.OrderBy.Hash() + `;` + - self.GroupBy.Hash() + `;` + - string(self.Extra) + `;` + - self.Where.Hash() + - `)` - return hash +func (layout *Template) doCompile(c Fragment) string { + if c != nil { + return c.Compile(layout) + } + return "" +} + +func (s Statement) getHash(h cache.Hashable) string { + if h != nil { + return h.Hash() + } + return "" +} + +// Hash returns a unique identifier. +func (s *Statement) Hash() string { + if s.hash == "" { + parts := strings.Join([]string{ + strconv.Itoa(int(s.Type)), + s.getHash(s.Table), + s.getHash(s.Database), + strconv.Itoa(int(s.Limit)), + strconv.Itoa(int(s.Offset)), + s.getHash(s.Columns), + s.getHash(s.Values), + s.getHash(s.ColumnValues), + s.getHash(s.OrderBy), + s.getHash(s.GroupBy), + string(s.Extra), + s.getHash(s.Where), + }, ";") + + s.hash = `Statement(` + parts + `)` + } + return s.hash } -func (self *Statement) Compile(layout *Template) (compiled string) { +// Compile transforms the Statement into an equivalent SQL query. +func (s *Statement) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(s); ok { + return z } - data := statement_s{ - Table: self.Table.Compile(layout), - Database: self.Database.Compile(layout), - Limit: self.Limit, - Offset: self.Offset, - Columns: self.Columns.Compile(layout), - Values: self.Values.Compile(layout), - ColumnValues: self.ColumnValues.Compile(layout), - OrderBy: self.OrderBy.Compile(layout), - GroupBy: self.GroupBy.Compile(layout), - Extra: string(self.Extra), - Where: self.Where.Compile(layout), + data := statementT{ + Table: layout.doCompile(s.Table), + Database: layout.doCompile(s.Database), + Limit: s.Limit, + Offset: s.Offset, + Columns: layout.doCompile(s.Columns), + Values: layout.doCompile(s.Values), + ColumnValues: layout.doCompile(s.ColumnValues), + OrderBy: layout.doCompile(s.OrderBy), + GroupBy: layout.doCompile(s.GroupBy), + Extra: string(s.Extra), + Where: layout.doCompile(s.Where), } - switch self.Type { - case SqlTruncate: + switch s.Type { + case Truncate: compiled = mustParse(layout.TruncateLayout, data) - case SqlDropTable: + case DropTable: compiled = mustParse(layout.DropTableLayout, data) - case SqlDropDatabase: + case DropDatabase: compiled = mustParse(layout.DropDatabaseLayout, data) - case SqlSelectCount: - compiled = mustParse(layout.SelectCountLayout, data) - case SqlSelect: + case Count: + compiled = mustParse(layout.CountLayout, data) + case Select: compiled = mustParse(layout.SelectLayout, data) - case SqlDelete: + case Delete: compiled = mustParse(layout.DeleteLayout, data) - case SqlUpdate: + case Update: compiled = mustParse(layout.UpdateLayout, data) - case SqlInsert: + case Insert: compiled = mustParse(layout.InsertLayout, data) + default: + panic("Unknown template type.") } - layout.Write(self, compiled) + layout.Write(s, compiled) return compiled } diff --git a/util/sqlgen/statement_test.go b/util/sqlgen/statement_test.go new file mode 100644 index 0000000000000000000000000000000000000000..34859a2d5ee18a0a56e647850aae60d027ed77d1 --- /dev/null +++ b/util/sqlgen/statement_test.go @@ -0,0 +1,757 @@ +package sqlgen + +import ( + "regexp" + "strings" + "testing" +) + +var ( + reInvisible = regexp.MustCompile(`[\t\n\r]`) + reSpace = regexp.MustCompile(`\s+`) +) + +func trim(a string) string { + a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") + a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") + return a +} + +func TestTruncateTable(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Truncate, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `TRUNCATE TABLE "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDropTable(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: DropTable, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `DROP TABLE "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDropDatabase(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: DropDatabase, + Database: &Database{Name: "table_name"}, + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `DROP DATABASE "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestCount(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Count, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT COUNT(1) AS _t FROM "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestCountRelation(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Count, + Table: TableWithName("information_schema.tables"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT COUNT(1) AS _t FROM "information_schema"."tables"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestCountWhere(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(RawValue("7"))}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFromAlias(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table"."name" AS "foo"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFromRawWhere(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + stmt = Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + &Raw{Value: "baz.id = exp.baz_id"}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFromMany(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectArtistNameFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("artist"), + Columns: JoinColumns( + &Column{Name: "artist.name"}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "artist"."name" FROM "artist"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectRawFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName(`artist`), + Columns: JoinColumns( + &Column{Name: `artist.name`}, + &Column{Name: Raw{Value: `CONCAT(artist.name, " ", artist.last_name)`}}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWithLimitOffset(t *testing.T) { + var s, e string + var stmt Statement + + // LIMIT only. + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // OFFSET only. + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Offset: 17, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // LIMIT AND OFFSET. + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Offset: 17, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestStatementGroupBy(t *testing.T) { + var s, e string + var stmt Statement + + // Simple GROUP BY + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWithOrderBy(t *testing.T) { + var s, e string + var stmt Statement + + // Simple ORDER BY + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY field ASC + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY field DESC + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY many fields + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + &SortColumn{Column: &Column{Name: "bar"}, Order: Ascendent}, + &SortColumn{Column: &Column{Name: "baz"}, Order: Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY function + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: Raw{Value: "FOO()"}}, Order: Descendent}, + &SortColumn{Column: &Column{Name: Raw{Value: "BAR()"}}, Order: Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWhere(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + Limit: 10, + Offset: 23, + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDelete(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Delete, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `DELETE FROM "table_name" WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestUpdate(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + stmt = Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + &ColumnValue{Column: &Column{Name: "bar"}, Operator: "=", Value: NewValue(Raw{Value: "88"})}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestInsert(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestInsertExtra(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Extra: "RETURNING id", + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING id` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkStatementSimpleQuery(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + ), + } + + for i := 0; i < b.N; i++ { + _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementSimpleQueryHash(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + ), + } + + for i := 0; i < b.N; i++ { + _ = stmt.Hash() + } +} + +func BenchmarkStatementSimpleQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + ), + } + _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQuery(b *testing.B) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + + for i := 0; i < b.N; i++ { + _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + _ = stmt.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/table.go b/util/sqlgen/table.go index 257f65ba8d500f995d0fbedf1d11ebe7ba5b0e68..43dcda62cf50ba9d85020915916aa211954c186f 100644 --- a/util/sqlgen/table.go +++ b/util/sqlgen/table.go @@ -13,6 +13,7 @@ type tableT struct { // Table struct represents a SQL table. type Table struct { Name interface{} + hash string } func quotedTableName(layout *Template, input string) string { @@ -33,7 +34,7 @@ func quotedTableName(layout *Template, input string) string { for i := range nameChunks { // nameChunks[i] = strings.TrimSpace(nameChunks[i]) nameChunks[i] = trimString(nameChunks[i]) - nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{nameChunks[i]}) + nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) } name = strings.Join(nameChunks, layout.ColumnSeparator) @@ -43,28 +44,44 @@ func quotedTableName(layout *Template, input string) string { if len(chunks) > 1 { // alias = strings.TrimSpace(chunks[1]) alias = trimString(chunks[1]) - alias = mustParse(layout.IdentifierQuote, Raw{alias}) + alias = mustParse(layout.IdentifierQuote, Raw{Value: alias}) } return mustParse(layout.TableAliasLayout, tableT{name, alias}) } +// TableWithName creates an returns a Table with the given name. +func TableWithName(name string) *Table { + return &Table{Name: name} +} + // Hash returns a string hash of the table value. -func (t Table) Hash() string { - switch t := t.Name.(type) { - case cc: - return `Table(` + t.Hash() + `)` - case string: - return `Table(` + t + `)` +func (t *Table) Hash() string { + if t.hash == "" { + var s string + + switch v := t.Name.(type) { + case Fragment: + s = v.Hash() + case fmt.Stringer: + s = v.String() + case string: + s = v + default: + s = fmt.Sprintf("%v", t.Name) + } + + t.hash = fmt.Sprintf(`Table{Name:%q}`, s) } - return fmt.Sprintf(`Table(%v)`, t.Name) + + return t.hash } // Compile transforms a table struct into a SQL chunk. -func (t Table) Compile(layout *Template) (compiled string) { +func (t *Table) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(t); ok { - return c + if z, ok := layout.Read(t); ok { + return z } switch value := t.Name.(type) { diff --git a/util/sqlgen/table_test.go b/util/sqlgen/table_test.go index dbae8be34841b8255a383c519fbd32d8b0444f43..df9528de00911d4e5d1eb11ab99526db93c55d11 100644 --- a/util/sqlgen/table_test.go +++ b/util/sqlgen/table_test.go @@ -6,9 +6,8 @@ import ( func TestTableSimple(t *testing.T) { var s, e string - var table Table - table = Table{"artist"} + table := TableWithName("artist") s = trim(table.Compile(defaultTemplate)) e = `"artist"` @@ -20,9 +19,8 @@ func TestTableSimple(t *testing.T) { func TestTableCompound(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo"} + table := TableWithName("artist.foo") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo"` @@ -34,9 +32,8 @@ func TestTableCompound(t *testing.T) { func TestTableCompoundAlias(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo AS baz"} + table := TableWithName("artist.foo AS baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo" AS "baz"` @@ -48,9 +45,8 @@ func TestTableCompoundAlias(t *testing.T) { func TestTableImplicitAlias(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo baz"} + table := TableWithName("artist.foo baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo" AS "baz"` @@ -62,9 +58,8 @@ func TestTableImplicitAlias(t *testing.T) { func TestTableMultiple(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo, artist.bar, artist.baz"} + table := TableWithName("artist.foo, artist.bar, artist.baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo", "artist"."bar", "artist"."baz"` @@ -76,9 +71,8 @@ func TestTableMultiple(t *testing.T) { func TestTableMultipleAlias(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo AS foo, artist.bar as bar, artist.baz As baz"} + table := TableWithName("artist.foo AS foo, artist.bar as bar, artist.baz As baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo" AS "foo", "artist"."bar" AS "bar", "artist"."baz" AS "baz"` @@ -90,9 +84,8 @@ func TestTableMultipleAlias(t *testing.T) { func TestTableMinimal(t *testing.T) { var s, e string - var table Table - table = Table{"a"} + table := TableWithName("a") s = trim(table.Compile(defaultTemplate)) e = `"a"` @@ -104,9 +97,8 @@ func TestTableMinimal(t *testing.T) { func TestTableEmpty(t *testing.T) { var s, e string - var table Table - table = Table{""} + table := TableWithName("") s = trim(table.Compile(defaultTemplate)) e = `` @@ -115,3 +107,30 @@ func TestTableEmpty(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkTableWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = TableWithName("foo") + } +} + +func BenchmarkTableHash(b *testing.B) { + t := TableWithName("name") + for i := 0; i < b.N; i++ { + t.Hash() + } +} + +func BenchmarkTableCompile(b *testing.B) { + t := TableWithName("name") + for i := 0; i < b.N; i++ { + t.Compile(defaultTemplate) + } +} + +func BenchmarkTableCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + t := TableWithName("name") + t.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/template.go b/util/sqlgen/template.go index e3ff175cb0619793fe32f907581eb7de2e687104..bea487f39f26fb4061b9aeef3c81bd2973f7fba5 100644 --- a/util/sqlgen/template.go +++ b/util/sqlgen/template.go @@ -1,9 +1,41 @@ package sqlgen import ( + "bytes" + "text/template" + "upper.io/cache" ) +// Type is the type of SQL query the statement represents. +type Type uint + +// Values for Type. +const ( + Truncate = Type(iota) + DropTable + DropDatabase + Count + Insert + Select + Update + Delete +) + +type ( + // Limit represents the SQL limit in a query. + Limit int + // Offset represents the SQL offset in a query. + Offset int + // Extra represents any custom SQL that is to be appended to the query. + Extra string +) + +var ( + parsedTemplates = make(map[string]*template.Template) +) + +// Template is an SQL template. type Template struct { ColumnSeparator string IdentifierSeparator string @@ -16,6 +48,7 @@ type Template struct { DescKeyword string AscKeyword string DefaultOperator string + AssignmentOperator string ClauseGroup string ClauseOperator string ColumnValue string @@ -31,7 +64,22 @@ type Template struct { TruncateLayout string DropDatabaseLayout string DropTableLayout string - SelectCountLayout string + CountLayout string GroupByLayout string *cache.Cache } + +func mustParse(text string, data interface{}) string { + var b bytes.Buffer + var ok bool + + if _, ok = parsedTemplates[text]; !ok { + parsedTemplates[text] = template.Must(template.New("").Parse(text)) + } + + if err := parsedTemplates[text].Execute(&b, data); err != nil { + panic("There was an error compiling the following template:\n" + text + "\nError was: " + err.Error()) + } + + return b.String() +} diff --git a/util/sqlgen/util_test.go b/util/sqlgen/util_test.go deleted file mode 100644 index 0551cf34fea3e9d0e088a79b8d313daff426cf1e..0000000000000000000000000000000000000000 --- a/util/sqlgen/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package sqlgen - -import ( - "regexp" - "strings" -) - -var ( - reInvisible = regexp.MustCompile(`[\t\n\r]`) - reSpace = regexp.MustCompile(`\s+`) -) - -func trim(a string) string { - a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") - a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") - return a -} diff --git a/util/sqlgen/utilities.go b/util/sqlgen/utilities.go index 40bfe1e04ab17d458df54f2b5eac3f68f69e1f29..305ab209026dfac06483cfb5de8a78a0377eb97a 100644 --- a/util/sqlgen/utilities.go +++ b/util/sqlgen/utilities.go @@ -10,50 +10,61 @@ const ( stageClose ) -func isSpace(in byte) bool { +// isBlankSymbol returns true if the given byte is either space, tab, carriage +// return or newline. +func isBlankSymbol(in byte) bool { return in == ' ' || in == '\t' || in == '\r' || in == '\n' } -func trimString(in string) string { +// trimString returns a slice of s with a leading and trailing blank symbols +// (as defined by isBlankSymbol) removed. +func trimString(s string) string { - start, end := 0, len(in)-1 + // This conversion is rather slow. + // return string(trimBytes([]byte(s))) - // Where do we start cutting? - for ; start <= end; start++ { - if isSpace(in[start]) == false { - break - } + start, end := 0, len(s)-1 + + if end < start { + return "" } - // Where do we end cutting? - for ; end >= start; end-- { - if isSpace(in[end]) == false { - break + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return "" } } - return in[start : end+1] + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] } -func trimByte(in []byte) []byte { +// trimBytes returns a slice of s with a leading and trailing blank symbols (as +// defined by isBlankSymbol) removed. +func trimBytes(s []byte) []byte { - start, end := 0, len(in)-1 + start, end := 0, len(s)-1 - // Where do we start cutting? - for ; start <= end; start++ { - if isSpace(in[start]) == false { - break - } + if end < start { + return []byte{} } - // Where do we end cutting? - for ; end >= start; end-- { - if isSpace(in[end]) == false { - break + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return []byte{} } } - return in[start : end+1] + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] } /* @@ -95,15 +106,12 @@ func separateByComma(in string) (out []string) { // Separates by spaces, ignoring spaces too. func separateBySpace(in string) (out []string) { - l := len(in) - - if l == 0 { + if len(in) == 0 { return []string{""} } - out = make([]string, 0, l) - pre := strings.Split(in, " ") + out = make([]string, 0, len(pre)) for i := range pre { pre[i] = trimString(pre[i]) @@ -119,7 +127,7 @@ func separateByAS(in string) (out []string) { out = []string{} if len(in) < 6 { - // Min expression: "a AS b" + // The minimum expression with the AS keyword is "x AS y", 6 chars. return []string{in} } @@ -129,7 +137,7 @@ func separateByAS(in string) (out []string) { var end int for end = start; end <= lim; end++ { - if end > 3 && isSpace(in[end]) && isSpace(in[end-3]) { + if end > 3 && isBlankSymbol(in[end]) && isBlankSymbol(in[end-3]) { if (in[end-1] == 's' || in[end-1] == 'S') && (in[end-2] == 'a' || in[end-2] == 'A') { break } diff --git a/util/sqlgen/utilities_test.go b/util/sqlgen/utilities_test.go index 87cff171c458d7b80a817a5c58d1d9398dfba091..4d5ec550618b0759bcab511f92cbfcc06c67fcdc 100644 --- a/util/sqlgen/utilities_test.go +++ b/util/sqlgen/utilities_test.go @@ -5,45 +5,63 @@ import ( "regexp" "strings" "testing" + "unicode" ) -func TestUtilIsSpace(t *testing.T) { - if isSpace(' ') == false { +const ( + blankSymbol = ' ' + stringWithCommas = "Hello,,World!,Enjoy" + stringWithSpaces = " Hello World! Enjoy" + stringWithASKeyword = "table.Name AS myTableAlias" +) + +var ( + bytesWithLeadingBlanks = []byte(" Hello world! ") + stringWithLeadingBlanks = string(bytesWithLeadingBlanks) +) + +func TestUtilIsBlankSymbol(t *testing.T) { + if isBlankSymbol(' ') == false { t.Fail() } - if isSpace('\n') == false { + if isBlankSymbol('\n') == false { t.Fail() } - if isSpace('\t') == false { + if isBlankSymbol('\t') == false { t.Fail() } - if isSpace('\r') == false { + if isBlankSymbol('\r') == false { t.Fail() } - if isSpace('x') == true { + if isBlankSymbol('x') == true { t.Fail() } } -func TestUtilTrimByte(t *testing.T) { +func TestUtilTrimBytes(t *testing.T) { var trimmed []byte - trimmed = trimByte([]byte(" \t\nHello World! \n")) + trimmed = trimBytes([]byte(" \t\nHello World! \n")) if string(trimmed) != "Hello World!" { t.Fatalf("Got: %s\n", string(trimmed)) } - trimmed = trimByte([]byte("Nope")) + trimmed = trimBytes([]byte("Nope")) if string(trimmed) != "Nope" { t.Fatalf("Got: %s\n", string(trimmed)) } - trimmed = trimByte([]byte("")) + trimmed = trimBytes([]byte("")) if string(trimmed) != "" { t.Fatalf("Got: %s\n", string(trimmed)) } - trimmed = trimByte(nil) + trimmed = trimBytes([]byte(" ")) + if string(trimmed) != "" { + t.Fatalf("Got: %s\n", string(trimmed)) + } + + trimmed = trimBytes(nil) if string(trimmed) != "" { t.Fatalf("Got: %s\n", string(trimmed)) } @@ -191,81 +209,76 @@ func TestUtilSeparateByAS(t *testing.T) { } } -func BenchmarkUtilIsSpace(b *testing.B) { +func BenchmarkUtilIsBlankSymbol(b *testing.B) { for i := 0; i < b.N; i++ { - _ = isSpace(' ') + _ = isBlankSymbol(blankSymbol) } } -func BenchmarkUtilTrimByte(b *testing.B) { - s := []byte(" Hello world! ") +func BenchmarkUtilStdlibIsBlankSymbol(b *testing.B) { for i := 0; i < b.N; i++ { - _ = trimByte(s) + _ = unicode.IsSpace(blankSymbol) } } -func BenchmarkUtilTrimString(b *testing.B) { - s := " Hello world! " +func BenchmarkUtilTrimBytes(b *testing.B) { for i := 0; i < b.N; i++ { - _ = trimString(s) + _ = trimBytes(bytesWithLeadingBlanks) + } +} +func BenchmarkUtilStdlibBytesTrimSpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = bytes.TrimSpace(bytesWithLeadingBlanks) } } -func BenchmarkUtilStdBytesTrimSpace(b *testing.B) { - s := []byte(" Hello world! ") +func BenchmarkUtilTrimString(b *testing.B) { for i := 0; i < b.N; i++ { - _ = bytes.TrimSpace(s) + _ = trimString(stringWithLeadingBlanks) } } -func BenchmarkUtilStdStringsTrimSpace(b *testing.B) { - s := " Hello world! " +func BenchmarkUtilStdlibStringsTrimSpace(b *testing.B) { for i := 0; i < b.N; i++ { - _ = strings.TrimSpace(s) + _ = strings.TrimSpace(stringWithLeadingBlanks) } } func BenchmarkUtilSeparateByComma(b *testing.B) { - s := "Hello,,World!,Enjoy" for i := 0; i < b.N; i++ { - _ = separateByComma(s) + _ = separateByComma(stringWithCommas) } } -func BenchmarkUtilSeparateBySpace(b *testing.B) { - s := " Hello World! Enjoy" +func BenchmarkUtilRegExpSeparateByComma(b *testing.B) { + sep := regexp.MustCompile(`\s*?,\s*?`) for i := 0; i < b.N; i++ { - _ = separateBySpace(s) + _ = sep.Split(stringWithCommas, -1) } } -func BenchmarkUtilSeparateByAS(b *testing.B) { - s := "table.Name AS myTableAlias" +func BenchmarkUtilSeparateBySpace(b *testing.B) { for i := 0; i < b.N; i++ { - _ = separateByAS(s) + _ = separateBySpace(stringWithSpaces) } } -func BenchmarkUtilSeparateByCommaRegExp(b *testing.B) { - sep := regexp.MustCompile(`\s*?,\s*?`) - s := "Hello,,World!,Enjoy" +func BenchmarkUtilRegExpSeparateBySpace(b *testing.B) { + sep := regexp.MustCompile(`\s+`) for i := 0; i < b.N; i++ { - _ = sep.Split(s, -1) + _ = sep.Split(stringWithSpaces, -1) } } -func BenchmarkUtilSeparateBySpaceRegExp(b *testing.B) { - sep := regexp.MustCompile(`\s+`) - s := " Hello World! Enjoy" +func BenchmarkUtilSeparateByAS(b *testing.B) { for i := 0; i < b.N; i++ { - _ = sep.Split(s, -1) + _ = separateByAS(stringWithASKeyword) } } -func BenchmarkUtilSeparateByASRegExp(b *testing.B) { +func BenchmarkUtilRegExpSeparateByAS(b *testing.B) { sep := regexp.MustCompile(`(?i:\s+AS\s+)`) - s := "table.Name AS myTableAlias" for i := 0; i < b.N; i++ { - _ = sep.Split(s, -1) + _ = sep.Split(stringWithASKeyword, -1) } } diff --git a/util/sqlgen/value.go b/util/sqlgen/value.go index aae5f25efa97785090a2870cadd486a1d4bb5f86..eb224e237078f3de0a42dff6a929746ee2251957 100644 --- a/util/sqlgen/value.go +++ b/util/sqlgen/value.go @@ -1,70 +1,119 @@ package sqlgen import ( + //"database/sql/driver" "fmt" + //"log" "strings" ) -type Values []Value +// Values represents an array of Value. +type Values struct { + Values []Fragment + hash string +} +// Value represents an escaped SQL value. type Value struct { - Value interface{} + V interface{} + hash string +} + +// NewValue creates and returns a Value. +func NewValue(v interface{}) *Value { + return &Value{V: v} +} + +// JoinValues creates and returns an array of values. +func JoinValues(v ...Fragment) *Values { + return &Values{Values: v} } -func (self Value) Hash() string { - switch t := self.Value.(type) { - case cc: - return `Value(` + t.Hash() + `)` - case string: - return `Value(` + t + `)` +// Hash returns a unique identifier. +func (v *Value) Hash() string { + if v.hash == "" { + switch t := v.V.(type) { + case Fragment: + v.hash = `Value(` + t.Hash() + `)` + case string: + v.hash = `Value(` + t + `)` + default: + v.hash = fmt.Sprintf(`Value(%v)`, v.V) + } } - return fmt.Sprintf(`Value(%v)`, self.Value) + return v.hash } -func (self Value) Compile(layout *Template) (compiled string) { +// Compile transforms the Value into an equivalent SQL representation. +func (v *Value) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(v); ok { + return z } - if raw, ok := self.Value.(Raw); ok { - compiled = raw.Raw + if raw, ok := v.V.(Raw); ok { + compiled = raw.Compile(layout) + } else if raw, ok := v.V.(Fragment); ok { + compiled = raw.Compile(layout) } else { - compiled = mustParse(layout.ValueQuote, Raw{fmt.Sprintf(`%v`, self.Value)}) + compiled = mustParse(layout.ValueQuote, RawValue(fmt.Sprintf(`%v`, v.V))) } - layout.Write(self, compiled) + layout.Write(v, compiled) return } -func (self Values) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `Values(` + strings.Join(hash, `,`) + `)` +/* +func (v *Value) Scan(src interface{}) error { + log.Println("Scan(", src, ") on", v.V) + return nil } -func (self Values) Compile(layout *Template) (compiled string) { +func (v *Value) Value() (driver.Value, error) { + log.Println("Value() on", v.V) + return v.V, nil +} +*/ + +// Hash returns a unique identifier. +func (vs *Values) Hash() string { + if vs.hash == "" { + hash := make([]string, len(vs.Values)) + for i := range vs.Values { + hash[i] = vs.Values[i].Hash() + } + vs.hash = `Values(` + strings.Join(hash, `,`) + `)` + } + return vs.hash +} - if c, ok := layout.Read(self); ok { +// Compile transforms the Values into an equivalent SQL representation. +func (vs *Values) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(vs); ok { return c } - l := len(self) - + l := len(vs.Values) if l > 0 { chunks := make([]string, 0, l) - for i := 0; i < l; i++ { - chunks = append(chunks, self[i].Compile(layout)) + chunks = append(chunks, vs.Values[i].Compile(layout)) } - compiled = strings.Join(chunks, layout.ValueSeparator) } + layout.Write(vs, compiled) + return +} - layout.Write(self, compiled) +/* +func (vs Values) Scan(src interface{}) error { + log.Println("Values.Scan(", src, ")") + return nil +} - return +func (vs Values) Value() (driver.Value, error) { + log.Println("Values.Value()") + return vs, nil } +*/ diff --git a/util/sqlgen/value_test.go b/util/sqlgen/value_test.go index 1c5b5e30b4ce1406b2f3e49eff278c2abb9addd3..8c621700d7a2327fa65cb687468d52e3258a0218 100644 --- a/util/sqlgen/value_test.go +++ b/util/sqlgen/value_test.go @@ -6,9 +6,9 @@ import ( func TestValue(t *testing.T) { var s, e string - var val Value + var val *Value - val = Value{1} + val = NewValue(1) s = val.Compile(defaultTemplate) e = `'1'` @@ -17,7 +17,7 @@ func TestValue(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } - val = Value{Raw{"NOW()"}} + val = NewValue(&Raw{Value: "NOW()"}) s = val.Compile(defaultTemplate) e = `NOW()` @@ -29,13 +29,12 @@ func TestValue(t *testing.T) { func TestValues(t *testing.T) { var s, e string - var val Values - val = Values{ - Value{Raw{"1"}}, - Value{Raw{"2"}}, - Value{"3"}, - } + val := JoinValues( + &Value{V: &Raw{Value: "1"}}, + &Value{V: &Raw{Value: "2"}}, + &Value{V: "3"}, + ) s = val.Compile(defaultTemplate) e = `1, 2, '3'` @@ -44,3 +43,57 @@ func TestValues(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = NewValue("a") + } +} + +func BenchmarkValueHash(b *testing.B) { + v := NewValue("a") + for i := 0; i < b.N; i++ { + _ = v.Hash() + } +} + +func BenchmarkValueCompile(b *testing.B) { + v := NewValue("a") + for i := 0; i < b.N; i++ { + _ = v.Compile(defaultTemplate) + } +} + +func BenchmarkValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + v := NewValue("a") + _ = v.Compile(defaultTemplate) + } +} + +func BenchmarkValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinValues(NewValue("a"), NewValue("b")) + } +} + +func BenchmarkValuesHash(b *testing.B) { + vs := JoinValues(NewValue("a"), NewValue("b")) + for i := 0; i < b.N; i++ { + _ = vs.Hash() + } +} + +func BenchmarkValuesCompile(b *testing.B) { + vs := JoinValues(NewValue("a"), NewValue("b")) + for i := 0; i < b.N; i++ { + _ = vs.Compile(defaultTemplate) + } +} + +func BenchmarkValuesCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + vs := JoinValues(NewValue("a"), NewValue("b")) + _ = vs.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/where.go b/util/sqlgen/where.go index bcc9b677649258ac887fea097f5cab3b8f3eb3d3..64ea9f106424231e3c38dd73c405cbb26a0feb9e 100644 --- a/util/sqlgen/where.go +++ b/util/sqlgen/where.go @@ -1,84 +1,110 @@ package sqlgen import ( + "fmt" "strings" ) -type ( - Or []cc - And []cc - Where []cc -) +// Or represents an SQL OR operator. +type Or Where + +// And represents an SQL AND operator. +type And Where + +// Where represents an SQL WHERE clause. +type Where struct { + Conditions []Fragment + hash string +} type conds struct { Conds string } -func (self Or) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `Or(` + strings.Join(hash, `,`) + `)` +// WhereConditions creates and retuens a new Where. +func WhereConditions(conditions ...Fragment) *Where { + return &Where{Conditions: conditions} } -func (self Or) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c - } +// JoinWithOr creates and returns a new Or. +func JoinWithOr(conditions ...Fragment) *Or { + return &Or{Conditions: conditions} +} - compiled = groupCondition(layout, self, mustParse(layout.ClauseOperator, layout.OrKeyword)) +// JoinWithAnd creates and returns a new And. +func JoinWithAnd(conditions ...Fragment) *And { + return &And{Conditions: conditions} +} - layout.Write(self, compiled) +// Hash returns a unique identifier. +func (w *Where) Hash() string { + if w.hash == "" { + hash := make([]string, len(w.Conditions)) + for i := range w.Conditions { + hash[i] = w.Conditions[i].Hash() + } + w.hash = fmt.Sprintf(`Where{%s}`, strings.Join(hash, `, `)) + } + return w.hash +} - return +// Hash returns a unique identifier. +func (o *Or) Hash() string { + w := Where(*o) + return `Or(` + w.Hash() + `)` } -func (self And) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `And(` + strings.Join(hash, `,`) + `)` +// Hash returns a unique identifier. +func (a *And) Hash() string { + w := Where(*a) + return `Or(` + w.Hash() + `)` } -func (self And) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c +// Compile transforms the Or into an equivalent SQL representation. +func (o *Or) Compile(layout *Template) (compiled string) { + + if z, ok := layout.Read(o); ok { + return z } - compiled = groupCondition(layout, self, mustParse(layout.ClauseOperator, layout.AndKeyword)) + compiled = groupCondition(layout, o.Conditions, mustParse(layout.ClauseOperator, layout.OrKeyword)) - layout.Write(self, compiled) + layout.Write(o, compiled) return } -func (self Where) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +// Compile transforms the And into an equivalent SQL representation. +func (a *And) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(a); ok { + return c } - return `Where(` + strings.Join(hash, `,`) + `)` + + compiled = groupCondition(layout, a.Conditions, mustParse(layout.ClauseOperator, layout.AndKeyword)) + + layout.Write(a, compiled) + + return } -func (self Where) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { +// Compile transforms the Where into an equivalent SQL representation. +func (w *Where) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(w); ok { return c } - grouped := groupCondition(layout, self, mustParse(layout.ClauseOperator, layout.AndKeyword)) + grouped := groupCondition(layout, w.Conditions, mustParse(layout.ClauseOperator, layout.AndKeyword)) if grouped != "" { compiled = mustParse(layout.WhereLayout, conds{grouped}) } - layout.Write(self, compiled) + layout.Write(w, compiled) return } -func groupCondition(layout *Template, terms []cc, joinKeyword string) string { +func groupCondition(layout *Template, terms []Fragment, joinKeyword string) string { l := len(terms) chunks := make([]string, 0, l) diff --git a/util/sqlgen/where_test.go b/util/sqlgen/where_test.go index 111f79b41d44ed2c8f0a837d91b4c2f31f44b342..c4b2a182ac7e1cadcfdc216f029f529c4864290f 100644 --- a/util/sqlgen/where_test.go +++ b/util/sqlgen/where_test.go @@ -6,13 +6,12 @@ import ( func TestWhereAnd(t *testing.T) { var s, e string - var and And - and = And{ - ColumnValue{Column{"id"}, ">", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "<", Value{Raw{"99"}}}, - ColumnValue{Column{"name"}, "=", Value{"John"}}, - } + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + ) s = and.Compile(defaultTemplate) e = `("id" > 8 AND "id" < 99 AND "name" = 'John')` @@ -24,12 +23,11 @@ func TestWhereAnd(t *testing.T) { func TestWhereOr(t *testing.T) { var s, e string - var or Or - or = Or{ - ColumnValue{Column{"id"}, "=", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "=", Value{Raw{"99"}}}, - } + or := JoinWithOr( + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "99"})}, + ) s = or.Compile(defaultTemplate) e = `("id" = 8 OR "id" = 99)` @@ -41,17 +39,16 @@ func TestWhereOr(t *testing.T) { func TestWhereAndOr(t *testing.T) { var s, e string - var and And - - and = And{ - ColumnValue{Column{"id"}, ">", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "<", Value{Raw{"99"}}}, - ColumnValue{Column{"name"}, "=", Value{"John"}}, - Or{ - ColumnValue{Column{"last_name"}, "=", Value{"Smith"}}, - ColumnValue{Column{"last_name"}, "=", Value{"Reyes"}}, - }, - } + + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + ) s = and.Compile(defaultTemplate) e = `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))` @@ -63,24 +60,23 @@ func TestWhereAndOr(t *testing.T) { func TestWhereAndRawOrAnd(t *testing.T) { var s, e string - var where Where - - where = Where{ - And{ - ColumnValue{Column{"id"}, ">", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "<", Value{Raw{"99"}}}, - }, - ColumnValue{Column{"name"}, "=", Value{"John"}}, - Raw{"city_id = 728"}, - Or{ - ColumnValue{Column{"last_name"}, "=", Value{"Smith"}}, - ColumnValue{Column{"last_name"}, "=", Value{"Reyes"}}, - }, - And{ - ColumnValue{Column{"age"}, ">", Value{Raw{"18"}}}, - ColumnValue{Column{"age"}, "<", Value{Raw{"41"}}}, - }, - } + + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) s = trim(where.Compile(defaultTemplate)) e = `WHERE (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))` @@ -89,3 +85,29 @@ func TestWhereAndRawOrAnd(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkWhere(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + } +} + +func BenchmarkCompileWhere(b *testing.B) { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + for i := 0; i < b.N; i++ { + w.Compile(defaultTemplate) + } +} + +func BenchmarkCompileWhereNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + w.Compile(defaultTemplate) + } +} diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go new file mode 100644 index 0000000000000000000000000000000000000000..96f2c552c0362aaf476588492190e03d4eee9fb8 --- /dev/null +++ b/util/sqlutil/convert.go @@ -0,0 +1,201 @@ +package sqlutil + +import ( + "fmt" + "reflect" + "strings" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" +) + +var ( + sqlPlaceholder = sqlgen.RawValue(`?`) + sqlNull = sqlgen.RawValue(`NULL`) +) + +type TemplateWithUtils struct { + *sqlgen.Template +} + +func NewTemplateWithUtils(template *sqlgen.Template) *TemplateWithUtils { + return &TemplateWithUtils{template} +} + +// ToWhereWithArguments converts the given db.Cond parameters into a sqlgen.Where +// value. +func (tu *TemplateWithUtils) ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + for i := range t { + w, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case db.And: + var op sqlgen.And + for i := range t { + k, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + op.Conditions = append(op.Conditions, k.Conditions...) + } + where.Conditions = append(where.Conditions, &op) + return + case db.Or: + var op sqlgen.Or + for i := range t { + w, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + op.Conditions = append(op.Conditions, w.Conditions...) + } + where.Conditions = append(where.Conditions, &op) + return + case db.Raw: + if s, ok := t.Value.(string); ok { + where.Conditions = append(where.Conditions, sqlgen.RawValue(s)) + } + return + case db.Cond: + cv, v := tu.ToColumnValues(t) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + return + case db.Constrainer: + cv, v := tu.ToColumnValues(t.Constraint()) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + return + } + + panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), term)) +} + +// ToInterfaceArguments converts the given value into an array of interfaces. +func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []interface{}) { + if value == nil { + return nil + } + + v := reflect.ValueOf(value) + + switch v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = v.Index(i).Interface() + } + + return args + } + return nil + default: + args = []interface{}{value} + } + + return args +} + +// ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct. +func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { + + args = []interface{}{} + + for column, value := range cond { + columnValue := sqlgen.ColumnValue{} + + // Guessing operator from input, or using a default one. + column := strings.TrimSpace(column) + chunks := strings.SplitN(column, ` `, 2) + + columnValue.Column = sqlgen.ColumnWithName(chunks[0]) + + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } else { + columnValue.Operator = tu.DefaultOperator + } + + switch value := value.(type) { + case db.Func: + v := tu.ToInterfaceArguments(value.Args) + columnValue.Operator = value.Name + + if v == nil { + // A function with no arguments. + columnValue.Value = sqlgen.RawValue(`()`) + } else { + // A function with one or more arguments. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } + + args = append(args, v...) + default: + v := tu.ToInterfaceArguments(value) + + l := len(v) + if v == nil || l == 0 { + // Nil value given. + columnValue.Value = sqlNull + } else { + if l > 1 { + // Array value given. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } else { + // Single value given. + columnValue.Value = sqlPlaceholder + } + args = append(args, v...) + } + } + + ToColumnValues.ColumnValues = append(ToColumnValues.ColumnValues, &columnValue) + } + + return ToColumnValues, args +} + +// ToColumnsValuesAndArguments maps the given columnNames and columnValues into +// sqlgen's Columns and Values, it also extracts and returns query arguments. +func (tu *TemplateWithUtils) ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { + var arguments []interface{} + + columns := new(sqlgen.Columns) + + columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames)) + for i := range columnNames { + columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i])) + } + + values := new(sqlgen.Values) + + arguments = make([]interface{}, 0, len(columnValues)) + values.Values = make([]sqlgen.Fragment, 0, len(columnValues)) + + for i := range columnValues { + switch v := columnValues[i].(type) { + case *sqlgen.Value: + // Adding value. + values.Values = append(values.Values, v) + case sqlgen.Value: + // Adding value. + values.Values = append(values.Values, &v) + default: + // Adding both value and placeholder. + values.Values = append(values.Values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + + return columns, values, arguments, nil +} diff --git a/util/sqlutil/debug.go b/util/sqlutil/debug.go index 08d8ebe9f9f03bad1ec16b428edcdf515f80e423..eef4baf6bae4f3446181734ef3117499806e7cc2 100644 --- a/util/sqlutil/debug.go +++ b/util/sqlutil/debug.go @@ -24,7 +24,10 @@ package sqlutil import ( "fmt" "log" + "os" "strings" + + "upper.io/v2/db" ) // Debug is used for printing SQL queries and arguments. @@ -59,3 +62,17 @@ func (d *Debug) Print() { log.Printf("\n\t%s\n\n", strings.Join(s, "\n\t")) } + +func IsDebugEnabled() bool { + if os.Getenv(db.EnvEnableDebug) != "" { + return true + } + return false +} + +func Log(query string, args []interface{}, err error, start int64, end int64) { + if IsDebugEnabled() { + d := Debug{query, args, err, start, end} + d.Print() + } +} diff --git a/util/sqlutil/fetch.go b/util/sqlutil/fetch.go index c11438de894ee71d47f9266777667496959ef125..9d1e1bc2b317247ee7da8aab2e1a2f76c0ef5139 100644 --- a/util/sqlutil/fetch.go +++ b/util/sqlutil/fetch.go @@ -22,17 +22,17 @@ package sqlutil import ( - "database/sql" + "encoding/json" "reflect" - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util" + "github.com/jmoiron/sqlx" + "github.com/jmoiron/sqlx/reflectx" + "upper.io/v2/db" ) -// FetchRow receives a *sql.Rows value and tries to map all the rows into a +// FetchRow receives a *sqlx.Rows value and tries to map all the rows into a // single struct given by the pointer `dst`. -func FetchRow(rows *sql.Rows, dst interface{}) error { +func FetchRow(rows *sqlx.Rows, dst interface{}) error { var columns []string var err error @@ -59,23 +59,29 @@ func FetchRow(rows *sql.Rows, dst interface{}) error { return db.ErrNoMoreRows } - item, err := fetchResult(itemV.Type(), rows, columns) + itemT := itemV.Type() + item, err := fetchResult(itemT, rows, columns) if err != nil { return err } - itemV.Set(reflect.Indirect(item)) + if itemT.Kind() == reflect.Ptr { + itemV.Set(item) + } else { + itemV.Set(reflect.Indirect(item)) + } return nil } -// FetchRows receives a *sql.Rows value and tries to map all the rows into a +// FetchRows receives a *sqlx.Rows value and tries to map all the rows into a // slice of structs given by the pointer `dst`. -func FetchRows(rows *sql.Rows, dst interface{}) error { - var columns []string +func FetchRows(rows *sqlx.Rows, dst interface{}) error { var err error + defer rows.Close() + // Destination. dstv := reflect.ValueOf(dst) @@ -91,6 +97,7 @@ func FetchRows(rows *sql.Rows, dst interface{}) error { return db.ErrExpectingSliceMapStruct } + var columns []string if columns, err = rows.Columns(); err != nil { return err } @@ -101,199 +108,176 @@ func FetchRows(rows *sql.Rows, dst interface{}) error { reset(dst) for rows.Next() { - item, err := fetchResult(itemT, rows, columns) - if err != nil { return err } - - slicev = reflect.Append(slicev, reflect.Indirect(item)) + if itemT.Kind() == reflect.Ptr { + slicev = reflect.Append(slicev, item) + } else { + slicev = reflect.Append(slicev, reflect.Indirect(item)) + } } - rows.Close() - dstv.Elem().Set(slicev) return nil } -// indirect function taken from encoding/json/decode.go -// -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -func indirect(v reflect.Value, decodingNull bool) (db.Unmarshaler, reflect.Value) { - // If v is a named type and is addressable, - // start with its address, so that if the type has pointer methods, - // we find them. - if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { - v = v.Addr() - } - for { - // Load value from interface, but only if the result will be - // usefully addressable. - if v.Kind() == reflect.Interface && !v.IsNil() { - e := v.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) { - v = e - continue - } - } - - if v.Kind() != reflect.Ptr { - break - } - - if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() { - break - } - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - if v.Type().NumMethod() > 0 { - if u, ok := v.Interface().(db.Unmarshaler); ok { - return u, reflect.Value{} - } - } - v = v.Elem() - } - return nil, v -} - -func fetchResult(itemT reflect.Type, rows *sql.Rows, columns []string) (reflect.Value, error) { +func fetchResult(itemT reflect.Type, rows *sqlx.Rows, columns []string) (reflect.Value, error) { var item reflect.Value var err error - switch itemT.Kind() { + objT := itemT + + switch objT.Kind() { case reflect.Map: - item = reflect.MakeMap(itemT) + item = reflect.MakeMap(objT) case reflect.Struct: - item = reflect.New(itemT) + item = reflect.New(objT) + case reflect.Ptr: + objT = itemT.Elem() + if objT.Kind() != reflect.Struct { + return item, db.ErrExpectingMapOrStruct + } + item = reflect.New(objT) default: return item, db.ErrExpectingMapOrStruct } - expecting := len(columns) + switch objT.Kind() { - // Allocating results. - values := make([]*sql.RawBytes, expecting) - scanArgs := make([]interface{}, expecting) + case reflect.Struct: - for i := range columns { - scanArgs[i] = &values[i] - } + values := make([]interface{}, len(columns)) + typeMap := rows.Mapper.TypeMap(itemT) + fieldMap := typeMap.Names + wrappedValues := map[*reflectx.Field]interface{}{} - if err = rows.Scan(scanArgs...); err != nil { - return item, err - } + for i, k := range columns { + fi, ok := fieldMap[k] + if !ok { + values[i] = new(interface{}) + continue + } + + // TODO: refactor into a nice pattern + if _, ok := fi.Options["stringarray"]; ok { + values[i] = &[]byte{} + wrappedValues[fi] = values[i] + } else if _, ok := fi.Options["int64array"]; ok { + values[i] = &[]byte{} + wrappedValues[fi] = values[i] + } else if _, ok := fi.Options["jsonb"]; ok { + values[i] = &[]byte{} + wrappedValues[fi] = values[i] + } else { + f := reflectx.FieldByIndexes(item, fi.Index) + values[i] = f.Addr().Interface() + } + + if u, ok := values[i].(db.Unmarshaler); ok { + values[i] = scanner{u} + } + } - // Range over row values. - for i, value := range values { + // Scanner - for reads + // Valuer - for writes - if value != nil { - // Real column name - column := columns[i] + // OptionTypes + // - before/after scan + // - before/after valuer.. - // Value as string. - svalue := string(*value) + if err = rows.Scan(values...); err != nil { + return item, err + } - var cv reflect.Value + // TODO: move this stuff out of here.. find a nice pattern + for fi, v := range wrappedValues { + var opt string + if _, ok := fi.Options["stringarray"]; ok { + opt = "stringarray" + } else if _, ok := fi.Options["int64array"]; ok { + opt = "int64array" + } else if _, ok := fi.Options["jsonb"]; ok { + opt = "jsonb" + } - v, _ := to.Convert(svalue, reflect.String) - cv = reflect.ValueOf(v) + b := v.(*[]byte) - switch itemT.Kind() { - // Destination is a map. - case reflect.Map: - if cv.Type() != itemT.Elem() { - if itemT.Elem().Kind() == reflect.Interface { - cv, _ = util.StringToType(svalue, cv.Type()) - } else { - cv, _ = util.StringToType(svalue, itemT.Elem()) - } + f := reflectx.FieldByIndexesReadOnly(item, fi.Index) + + switch opt { + case "stringarray": + v := StringArray{} + err := v.Scan(*b) + if err != nil { + return item, err + } + f.Set(reflect.ValueOf(v)) + case "int64array": + v := Int64Array{} + err := v.Scan(*b) + if err != nil { + return item, err } - if cv.IsValid() { - item.SetMapIndex(reflect.ValueOf(column), cv) + f.Set(reflect.ValueOf(v)) + case "jsonb": + if len(*b) == 0 { + continue } - // Destionation is a struct. - case reflect.Struct: - index := util.GetStructFieldIndex(itemT, column) + var vv reflect.Value + t := reflect.PtrTo(f.Type()) + + switch t.Kind() { + case reflect.Map: + vv = reflect.MakeMap(t) + case reflect.Slice: + vv = reflect.MakeSlice(t, 0, 0) + default: + vv = reflect.New(t) + } + + err := json.Unmarshal(*b, vv.Interface()) + if err != nil { + return item, err + } - if index == nil { + vv = vv.Elem().Elem() + + if !vv.IsValid() || (vv.Kind() == reflect.Ptr && vv.IsNil()) { continue - } else { - - // Destination field. - destf := item.Elem().FieldByIndex(index) - - if destf.IsValid() { - - if cv.Type() != destf.Type() { - - if destf.Type().Kind() != reflect.Interface { - - switch destf.Type() { - case nullFloat64Type: - nullFloat64 := sql.NullFloat64{} - if svalue != `` { - nullFloat64.Scan(svalue) - } - cv = reflect.ValueOf(nullFloat64) - case nullInt64Type: - nullInt64 := sql.NullInt64{} - if svalue != `` { - nullInt64.Scan(svalue) - } - cv = reflect.ValueOf(nullInt64) - case nullBoolType: - nullBool := sql.NullBool{} - if svalue != `` { - nullBool.Scan(svalue) - } - cv = reflect.ValueOf(nullBool) - case nullStringType: - nullString := sql.NullString{} - nullString.Scan(svalue) - cv = reflect.ValueOf(nullString) - default: - var decodingNull bool - - if svalue == "" { - decodingNull = true - } - - u, _ := indirect(destf, decodingNull) - - if u != nil { - u.UnmarshalDB(svalue) - - if destf.Kind() == reflect.Interface || destf.Kind() == reflect.Ptr { - cv = reflect.ValueOf(u) - } else { - cv = reflect.ValueOf(u).Elem() - } - - } else { - cv, _ = util.StringToType(svalue, destf.Type()) - } - - } - } - - } - - // Copying value. - if cv.IsValid() { - destf.Set(cv) - } - - } } + f.Set(vv) + } + } + + case reflect.Map: + + columns, err := rows.Columns() + if err != nil { + return item, err + } + + values := make([]interface{}, len(columns)) + for i := range values { + if itemT.Elem().Kind() == reflect.Interface { + values[i] = new(interface{}) + } else { + values[i] = reflect.New(itemT.Elem()).Interface() } } + + if err = rows.Scan(values...); err != nil { + return item, err + } + + for i, column := range columns { + item.SetMapIndex(reflect.ValueOf(column), reflect.Indirect(reflect.ValueOf(values[i]))) + } + } return item, nil diff --git a/util/sqlutil/main.go b/util/sqlutil/main.go deleted file mode 100644 index 3c8bf46763acdc27aeacfd67dd736e50c032cfe1..0000000000000000000000000000000000000000 --- a/util/sqlutil/main.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package sqlutil - -import ( - "database/sql" - "reflect" - "regexp" - - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util" -) - -var ( - reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) -) - -var ( - nullInt64Type = reflect.TypeOf(sql.NullInt64{}) - nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) - nullBoolType = reflect.TypeOf(sql.NullBool{}) - nullStringType = reflect.TypeOf(sql.NullString{}) -) - -// T type is commonly used by adapters to map database/sql values to Go values -// using FieldValues() -type T struct { - Columns []string -} - -func (t *T) columnLike(s string) string { - for _, name := range t.Columns { - if util.NormalizeColumn(s) == util.NormalizeColumn(name) { - return name - } - } - return s -} - -func marshal(v interface{}) (interface{}, error) { - m, isM := v.(db.Marshaler) - - if isM { - var err error - if v, err = m.MarshalDB(); err != nil { - return nil, err - } - } - - return v, nil -} - -// FieldValues accepts a map or a struct and splits them into an array of -// columns and values. -func (t *T) FieldValues(item interface{}, convertFn func(interface{}) interface{}) ([]string, []interface{}, error) { - - fields := []string{} - values := []interface{}{} - - itemV := reflect.ValueOf(item) - itemT := itemV.Type() - - if itemT.Kind() == reflect.Ptr { - // Single derefence. Just in case user passed a pointer to struct instead of a struct. - item = itemV.Elem().Interface() - itemV = reflect.ValueOf(item) - itemT = itemV.Type() - } - - switch itemT.Kind() { - - case reflect.Struct: - - nfields := itemV.NumField() - - values = make([]interface{}, 0, nfields) - fields = make([]string, 0, nfields) - - for i := 0; i < nfields; i++ { - - field := itemT.Field(i) - - if field.PkgPath != `` { - // Field is unexported. - continue - } - - if field.Anonymous { - // It's an anonymous field. Let's skip it unless it has an explicit - // `db` tag. - if field.Tag.Get(`db`) == `` { - continue - } - } - - // Field options. - fieldName, fieldOptions := util.ParseTag(field.Tag.Get(`db`)) - - // Deprecated `field` tag. - if deprecatedField := field.Tag.Get(`field`); deprecatedField != `` { - fieldName = deprecatedField - } - - // Deprecated `omitempty` tag. - if deprecatedOmitEmpty := field.Tag.Get(`omitempty`); deprecatedOmitEmpty != `` { - fieldOptions[`omitempty`] = true - } - - // Deprecated `inline` tag. - if deprecatedInline := field.Tag.Get(`inline`); deprecatedInline != `` { - fieldOptions[`inline`] = true - } - - // Skipping field - if fieldName == `-` { - continue - } - - // Trying to match field name. - - // Explicit JSON or BSON options. - if fieldName == `` && fieldOptions[`bson`] { - // Using name from the BSON tag. - fieldName, _ = util.ParseTag(field.Tag.Get(`bson`)) - } - - if fieldName == `` && fieldOptions[`bson`] { - // Using name from the JSON tag. - fieldName, _ = util.ParseTag(field.Tag.Get(`bson`)) - } - - // Still don't have a match? try to match againt JSON. - if fieldName == `` { - fieldName, _ = util.ParseTag(field.Tag.Get(`json`)) - } - - // Still don't have a match? try to match againt BSON. - if fieldName == `` { - fieldName, _ = util.ParseTag(field.Tag.Get(`bson`)) - } - - // Nothing works, trying to match by name. - if fieldName == `` { - fieldName = t.columnLike(field.Name) - } - - // Processing tag options. - value := itemV.Field(i).Interface() - - if fieldOptions[`omitempty`] == true { - zero := reflect.Zero(reflect.TypeOf(value)).Interface() - if value == zero { - continue - } - } - - if fieldOptions[`inline`] == true { - infields, invalues, inerr := t.FieldValues(value, convertFn) - if inerr != nil { - return nil, nil, inerr - } - fields = append(fields, infields...) - values = append(values, invalues...) - } else { - fields = append(fields, fieldName) - v, err := marshal(convertFn(value)) - - if err != nil { - return nil, nil, err - } - - values = append(values, v) - } - } - case reflect.Map: - nfields := itemV.Len() - values = make([]interface{}, nfields) - fields = make([]string, nfields) - mkeys := itemV.MapKeys() - - for i, keyV := range mkeys { - valv := itemV.MapIndex(keyV) - fields[i] = t.columnLike(to.String(keyV.Interface())) - - v, err := marshal(convertFn(valv.Interface())) - - if err != nil { - return nil, nil, err - } - - values[i] = v - } - default: - return nil, nil, db.ErrExpectingMapOrStruct - } - - return fields, values, nil -} - -func reset(data interface{}) error { - // Resetting element. - v := reflect.ValueOf(data).Elem() - t := v.Type() - z := reflect.Zero(t) - v.Set(z) - return nil -} diff --git a/postgresql/result.go b/util/sqlutil/result/result.go similarity index 53% rename from postgresql/result.go rename to util/sqlutil/result/result.go index d1fbc640b95194f26f33bf9c4b3ebd06d7a0ea13..2426cead30b2d4373f8ed71ac35b634a1d6efc34 100644 --- a/postgresql/result.go +++ b/util/sqlutil/result/result.go @@ -19,25 +19,29 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package postgresql +package result import ( - "database/sql" "fmt" "strings" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" + "github.com/jmoiron/sqlx" + "upper.io/v2/db" + "upper.io/v2/db/util/sqlgen" + "upper.io/v2/db/util/sqlutil" +) + +var ( + sqlPlaceholder = sqlgen.RawValue(`?`) ) type counter struct { Total uint64 `db:"_t"` } -type result struct { - table *table - cursor *sql.Rows // This is the main query cursor. It starts as a nil value. +type Result struct { + table DataProvider + cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. limit sqlgen.Limit offset sqlgen.Offset columns sqlgen.Columns @@ -45,76 +49,73 @@ type result struct { orderBy sqlgen.OrderBy groupBy sqlgen.GroupBy arguments []interface{} + template *sqlutil.TemplateWithUtils +} + +// NewResult creates and results a new result set on the given table, this set +// is limited by the given sqlgen.Where conditions. +func NewResult(template *sqlutil.TemplateWithUtils, p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { + return &Result{ + table: p, + where: where, + arguments: arguments, + template: template, + } } // Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { +func (r *Result) setCursor() error { var err error // We need a cursor, if the cursor does not exists yet then we create one. if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, + r.cursor, err = r.table.Query(sqlgen.Statement{ + Type: sqlgen.Select, + Table: sqlgen.TableWithName(r.table.Name()), + Columns: &r.columns, Limit: r.limit, Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, + Where: &r.where, + OrderBy: &r.orderBy, + GroupBy: &r.groupBy, }, r.arguments...) } return err } // Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) +func (r *Result) Where(terms ...interface{}) db.Result { + r.where, r.arguments = r.template.ToWhereWithArguments(terms) return r } // Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { +func (r *Result) Limit(n uint) db.Result { r.limit = sqlgen.Limit(n) return r } // Determines how many documents will be skipped before starting to grab // results. -func (r *result) Skip(n uint) db.Result { +func (r *Result) Skip(n uint) db.Result { r.offset = sqlgen.Offset(n) return r } // Used to group results that have the same value in the same column or // columns. -func (r *result) Group(fields ...interface{}) db.Result { +func (r *Result) Group(fields ...interface{}) db.Result { + var columns []sqlgen.Fragment - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - groupByColumns = append(groupByColumns, sqlgen.Column{sqlgen.Raw{s}}) - case db.Raw: - groupByColumns = append(groupByColumns, sqlgen.Column{sqlgen.Raw{fmt.Sprintf("%v", value.Value)}}) - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) + for i := range fields { + switch v := fields[i].(type) { + case string: + columns = append(columns, sqlgen.ColumnWithName(v)) + case sqlgen.Fragment: + columns = append(columns, v) } } - r.groupBy = groupByColumns + r.groupBy = *sqlgen.GroupByColumns(columns...) return r } @@ -122,54 +123,52 @@ func (r *result) Group(fields ...interface{}) db.Result { // Determines sorting of results according to the provided names. Fields may be // prefixed by - (minus) which means descending order, ascending order would be // used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { +func (r *Result) Sort(fields ...interface{}) db.Result { - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) + var sortColumns sqlgen.SortColumns - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn + for i := range fields { + var sort *sqlgen.SortColumn switch value := fields[i].(type) { case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)), + Order: sqlgen.Ascendent, } case string: if strings.HasPrefix(value, `-`) { // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value[1:]), + Order: sqlgen.Descendent, } } else { // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value), + Order: sqlgen.Ascendent, } } } - sortColumns = append(sortColumns, sort) + sortColumns.Columns = append(sortColumns.Columns, sort) } - r.orderBy.SortColumns = sortColumns + r.orderBy.SortColumns = &sortColumns return r } // Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { +func (r *Result) Select(fields ...interface{}) db.Result { - r.columns = make(sqlgen.Columns, 0, len(fields)) + r.columns = sqlgen.Columns{} - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column + for i := range fields { + var col sqlgen.Fragment switch value := fields[i].(type) { case db.Func: - v := interfaceArgs(value.Args) + v := r.template.ToInterfaceArguments(value.Args) var s string if len(v) == 0 { s = fmt.Sprintf(`%s()`, value.Name) @@ -180,20 +179,20 @@ func (r *result) Select(fields ...interface{}) db.Result { } s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) } - col = sqlgen.Column{sqlgen.Raw{s}} + col = sqlgen.RawValue(s) case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} + col = sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)) default: - col = sqlgen.Column{value} + col = sqlgen.ColumnWithName(fmt.Sprintf(`%v`, value)) } - r.columns = append(r.columns, col) + r.columns.Columns = append(r.columns.Columns, col) } return r } // Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { +func (r *Result) All(dst interface{}) error { var err error if r.cursor != nil { @@ -216,7 +215,7 @@ func (r *result) All(dst interface{}) error { } // Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { +func (r *Result) One(dst interface{}) error { var err error if r.cursor != nil { @@ -231,15 +230,14 @@ func (r *result) One(dst interface{}) error { } // Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) error { - err := r.setCursor() - if err != nil { +func (r *Result) Next(dst interface{}) (err error) { + + if err = r.setCursor(); err != nil { r.Close() return err } - err = sqlutil.FetchRow(r.cursor, dst) - if err != nil { + if err = sqlutil.FetchRow(r.cursor, dst); err != nil { r.Close() return err } @@ -248,46 +246,48 @@ func (r *result) Next(dst interface{}) error { } // Removes the matching items from the collection. -func (r *result) Remove() error { +func (r *Result) Remove() error { var err error - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, + + _, err = r.table.Exec(sqlgen.Statement{ + Type: sqlgen.Delete, + Table: sqlgen.TableWithName(r.table.Name()), + Where: &r.where, }, r.arguments...) + return err } // Updates matching items from the collection with values of the given map or // struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values, toInternal) +func (r *Result) Update(values interface{}) error { - total := len(ff) + ff, vv, err := r.table.FieldValues(values) + if err != nil { + return err + } - cvs := make(sqlgen.ColumnValues, 0, total) + cvs := new(sqlgen.ColumnValues) - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) + for i := range ff { + cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: r.template.AssignmentOperator, Value: sqlPlaceholder}) } vv = append(vv, r.arguments...) - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, + _, err = r.table.Exec(sqlgen.Statement{ + Type: sqlgen.Update, + Table: sqlgen.TableWithName(r.table.Name()), ColumnValues: cvs, - Where: r.where, + Where: &r.where, }, vv...) return err } // Closes the result set. -func (r *result) Close() error { - var err error +func (r *Result) Close() (err error) { if r.cursor != nil { err = r.cursor.Close() r.cursor = nil @@ -296,22 +296,21 @@ func (r *result) Close() error { } // Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { +func (r *Result) Count() (uint64, error) { var count counter - rows, err := r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, + row, err := r.table.QueryRow(sqlgen.Statement{ + Type: sqlgen.Count, + Table: sqlgen.TableWithName(r.table.Name()), + Where: &r.where, }, r.arguments...) if err != nil { return 0, err } - defer rows.Close() - - if err = sqlutil.FetchRow(rows, &count); err != nil { + err = row.Scan(&count.Total) + if err != nil { return 0, err } diff --git a/util/sqlutil/result/table.go b/util/sqlutil/result/table.go new file mode 100644 index 0000000000000000000000000000000000000000..0984adcd966fc94f3fc1f576688dc1093979f468 --- /dev/null +++ b/util/sqlutil/result/table.go @@ -0,0 +1,15 @@ +package result + +import ( + "database/sql" + "github.com/jmoiron/sqlx" + "upper.io/v2/db/util/sqlgen" +) + +type DataProvider interface { + Name() string + Query(sqlgen.Statement, ...interface{}) (*sqlx.Rows, error) + QueryRow(sqlgen.Statement, ...interface{}) (*sqlx.Row, error) + Exec(sqlgen.Statement, ...interface{}) (sql.Result, error) + FieldValues(interface{}) ([]string, []interface{}, error) +} diff --git a/util/sqlutil/scanner.go b/util/sqlutil/scanner.go new file mode 100644 index 0000000000000000000000000000000000000000..270a501c31fde98a29a0571b26e4ce2c76d10ca8 --- /dev/null +++ b/util/sqlutil/scanner.go @@ -0,0 +1,188 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sqlutil + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "strconv" + "strings" + + "upper.io/v2/db" +) + +type scanner struct { + v db.Unmarshaler +} + +func (u scanner) Scan(v interface{}) error { + return u.v.UnmarshalDB(v) +} + +var _ sql.Scanner = scanner{} + +//------ + +type JsonbType struct { + V interface{} +} + +func (j *JsonbType) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + + v := JsonbType{} + if err := json.Unmarshal(b, &v.V); err != nil { + return err + } + *j = v + return nil +} + +func (j JsonbType) Value() (driver.Value, error) { + b, err := json.Marshal(j.V) + if err != nil { + return nil, err + } + return b, nil +} + +//------ + +type StringArray []string + +func (a *StringArray) Scan(src interface{}) error { + if src == nil { + *a = StringArray{} + return nil + } + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + if len(b) == 0 { + return nil + } + s := string(b)[1 : len(b)-1] + if s == "" { + return nil + } + results := strings.Split(s, ",") + *a = StringArray(results) + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedString(b, a[0]) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedString(b, a[i]) + } + + return append(b, '}'), nil + } + + return []byte{'{', '}'}, nil +} + +func appendArrayQuotedString(b []byte, v string) []byte { + b = append(b, '"') + for { + i := strings.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +//------ + +type Int64Array []int64 + +func (a *Int64Array) Scan(src interface{}) error { + if src == nil { + return nil + } + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + + s := string(b)[1 : len(b)-1] + parts := strings.Split(s, ",") + results := make([]int64, 0) + for _, n := range parts { + i, err := strconv.ParseInt(n, 10, 64) + if err != nil { + return err + } + results = append(results, i) + } + *a = Int64Array(results) + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return append(b, '}'), nil + } + + return []byte{'{', '}'}, nil +} diff --git a/util/sqlutil/sqlutil.go b/util/sqlutil/sqlutil.go new file mode 100644 index 0000000000000000000000000000000000000000..9b3ed2b4e90f4b534938c22a08eba01d3bd2d692 --- /dev/null +++ b/util/sqlutil/sqlutil.go @@ -0,0 +1,191 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sqlutil + +import ( + "database/sql" + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/jmoiron/sqlx/reflectx" + "upper.io/v2/db" +) + +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) + reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`) +) + +var ( + nullInt64Type = reflect.TypeOf(sql.NullInt64{}) + nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) + nullBoolType = reflect.TypeOf(sql.NullBool{}) + nullStringType = reflect.TypeOf(sql.NullString{}) +) + +// T type is commonly used by adapters to map database/sql values to Go values +// using FieldValues() +type T struct { + Columns []string + Mapper *reflectx.Mapper + Tables []string // Holds table names. +} + +func (t *T) columnLike(s string) string { + for _, name := range t.Columns { + if normalizeColumn(s) == normalizeColumn(name) { + return name + } + } + return s +} + +func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { + fields := []string{} + values := []interface{}{} + + itemV := reflect.ValueOf(item) + itemT := itemV.Type() + + if itemT.Kind() == reflect.Ptr { + // Single derefence. Just in case user passed a pointer to struct instead of a struct. + item = itemV.Elem().Interface() + itemV = reflect.ValueOf(item) + itemT = itemV.Type() + } + + switch itemT.Kind() { + + case reflect.Struct: + + fieldMap := t.Mapper.TypeMap(itemT).Names + nfields := len(fieldMap) + + values = make([]interface{}, 0, nfields) + fields = make([]string, 0, nfields) + + for _, fi := range fieldMap { + // log.Println("=>", fi.Name, fi.Options) + + fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) + if fld.Kind() == reflect.Ptr && fld.IsNil() { + continue + } + + var value interface{} + if _, ok := fi.Options["stringarray"]; ok { + value = StringArray(fld.Interface().([]string)) + } else if _, ok := fi.Options["int64array"]; ok { + value = Int64Array(fld.Interface().([]int64)) + } else if _, ok := fi.Options["jsonb"]; ok { + value = JsonbType{fld.Interface()} + } else { + value = fld.Interface() + } + + if _, ok := fi.Options["omitempty"]; ok { + if value == fi.Zero.Interface() { + continue + } + } + + // TODO: columnLike stuff...? + + fields = append(fields, fi.Name) + v, err := marshal(value) + if err != nil { + return nil, nil, err + } + values = append(values, v) + } + + case reflect.Map: + nfields := itemV.Len() + values = make([]interface{}, nfields) + fields = make([]string, nfields) + mkeys := itemV.MapKeys() + + for i, keyV := range mkeys { + valv := itemV.MapIndex(keyV) + fields[i] = t.columnLike(fmt.Sprintf("%v", keyV.Interface())) + + v, err := marshal(valv.Interface()) + if err != nil { + return nil, nil, err + } + + values[i] = v + } + + default: + return nil, nil, db.ErrExpectingMapOrStruct + } + + return fields, values, nil +} + +func marshal(v interface{}) (interface{}, error) { + if m, isMarshaler := v.(db.Marshaler); isMarshaler { + var err error + if v, err = m.MarshalDB(); err != nil { + return nil, err + } + } + return v, nil +} + +func reset(data interface{}) error { + // Resetting element. + v := reflect.ValueOf(data).Elem() + t := v.Type() + z := reflect.Zero(t) + v.Set(z) + return nil +} + +// normalizeColumn prepares a column for comparison against another column. +func normalizeColumn(s string) string { + return strings.ToLower(reColumnCompareExclude.ReplaceAllString(s, "")) +} + +// NewMapper creates a reflectx.Mapper +func NewMapper() *reflectx.Mapper { + return reflectx.NewMapper("db") +} + +// MainTableName returns the name of the first table. +func (t *T) MainTableName() string { + return t.NthTableName(0) +} + +// NthTableName returns the table name at index i. +func (t *T) NthTableName(i int) string { + if len(t.Tables) > i { + chunks := strings.SplitN(t.Tables[i], " ", 2) + if len(chunks) > 0 { + return chunks[0] + } + } + return "" +} diff --git a/mysql/tx.go b/util/sqlutil/tx/tx.go similarity index 78% rename from mysql/tx.go rename to util/sqlutil/tx/tx.go index 55a79efbbefbfe3bd8e51b19afd6a83fa079d8d3..533c54058f9f855dbc3004e8fad03f5f3022e0bb 100644 --- a/mysql/tx.go +++ b/util/sqlutil/tx/tx.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -19,26 +19,28 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package mysql +package sqltx import ( - "database/sql" + "github.com/jmoiron/sqlx" ) -type tx struct { - *source - sqlTx *sql.Tx - done bool +type Tx struct { + *sqlx.Tx + done bool } -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { +func New(tx *sqlx.Tx) *Tx { + return &Tx{Tx: tx} +} + +func (t *Tx) Done() bool { + return t.done +} + +func (t *Tx) Commit() (err error) { + if err = t.Tx.Commit(); err == nil { t.done = true } return err } - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/wrapper.go b/wrapper.go index 058cf6904532451ae4392250a7904f594efb8a39..8e3d01afd36bedfb014c674d6608c60bef8557eb 100644 --- a/wrapper.go +++ b/wrapper.go @@ -53,7 +53,7 @@ func Open(adapter string, conn ConnectionURL) (Database, error) { if ok == false { // Using panic instead of returning error because attemping to use an // adapter that does not exists will never result in success. - panic(fmt.Sprintf(`Open: Unknown adapter %s. (see: https://upper.io/db#database-adapters)`, adapter)) + panic(fmt.Sprintf(`Open: Unknown adapter %s. (see: https://upper.io/v2/db#database-adapters)`, adapter)) } // Creating a new connection everytime Open() is called.