diff --git a/.travis.yml b/.travis.yml index 54be8a906686bac0e111e2c544eca7e71d7013b2..e9af1942e1c2a687c8e2878f2eec0a61f6d9e6fa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,15 +13,16 @@ go: - 1.4.1 - 1.4.2 -env: GOARCH=amd64 TEST_HOST=127.0.0.1 +env: GOARCH=amd64 TEST_HOST=127.0.0.1 UPPERIO_DB_DEBUG=1 install: - - sudo apt-get install bzr - # - go get github.com/cznic/ql/ql # ql command line util. - # - go install github.com/cznic/ql/ql # ql command line util. + - sudo apt-get install -y bzr make + - go get github.com/cznic/ql/ql # ql command line util. + - go install github.com/cznic/ql/ql # ql command line util. - mkdir ../../../upper.io - ln -s $PWD ../../../upper.io/db - go get -t -d + - cd $GOPATH/src/github.com/jmoiron/sqlx && git checkout ptrs # Peter's branch. # - go get upper.io/db/mongo # - go get upper.io/db/mysql # - go get upper.io/db/postgresql @@ -35,10 +36,24 @@ before_script: - mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root mysql - cat mysql/_dumps/setup.sql | mysql -uroot - cat mysql/_dumps/structs.sql | mysql -uupperio -pupperio upperio_tests + - cat postgresql/_dumps/setup.sql | psql -U postgres - cat postgresql/_dumps/structs.sql | PGPASSWORD="upperio" psql -U upperio upperio_tests + - mongo upperio_tests --eval 'db.addUser("upperio", "upperio")' + + - (cd mysql/_dumps && make) + - (cd postgresql/_dumps && make) + - (cd sqlite/_dumps && make) + - (cd ql/_dumps && make) + # - cat ql/_dumps/structs.sql | $GOPATH/bin/ql -db ql/_dumps/test.db script: - - UPPERIO_DB_DEBUG=1 go test -test.v=1 + - go test upper.io/db/util/sqlgen -test.bench=. + - go test upper.io/db/postgresql -test.bench=. + - go test upper.io/db/mysql -test.bench=. + - go test upper.io/db/sqlite -test.bench=. + - go test upper.io/db/ql -test.bench=. + - go test upper.io/db/mongo -test.bench=. + - go test -test.v diff --git a/main.go b/db.go similarity index 98% rename from main.go rename to db.go index ec5659ac66cb77b628b1244c2d599d8bfc01c3eb..d28d94e2028e24099b40b275c18f23a618bc7815 100644 --- a/main.go +++ b/db.go @@ -100,6 +100,10 @@ type Func struct { // } type And []interface{} +func (a And) And(exp ...interface{}) And { + return append(a, exp...) +} + // Or is an array of interfaced that is used to join two or more expressions // under logical disjunction, it accepts `db.Cond{}`, `db.And{}`, `db.Raw{}` // and other `db.Or{}` values. @@ -113,6 +117,10 @@ type And []interface{} // } type Or []interface{} +func (o Or) Or(exp ...interface{}) Or { + return append(o, exp...) +} + // Raw holds chunks of data to be passed to the database without any filtering. // Use with care. // diff --git a/main_test.go b/db_test.go similarity index 100% rename from main_test.go rename to db_test.go diff --git a/error.go b/error.go index 631f435a5fea2aeab1dc489b92da7e8dafc61268..4110560b965ef194ec1cb14ce6995bac3242134b 100644 --- a/error.go +++ b/error.go @@ -47,7 +47,7 @@ var ( ErrUnsupportedDestination = errors.New(`Unsupported destination type.`) ErrUnsupportedType = errors.New(`This type does not support marshaling.`) ErrUnsupportedValue = errors.New(`This value does not support unmarshaling.`) - ErrUnknownConditionType = errors.New(`Arguments of type %s can't be used as constraints.`) + ErrUnknownConditionType = errors.New(`Arguments of type %T can't be used as constraints.`) ) // Deprecated but kept for backwards compatibility. See: https://github.com/upper/db/issues/18 diff --git a/mongo/database_test.go b/mongo/database_test.go index 1cd42677cd74b063d9b1a740e873031027a00bf1..fe426bb7f2a3edcd4bde08a68e905968dde222ee 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -24,7 +24,6 @@ package mongo import ( "errors" - "flag" "math/rand" "os" "reflect" @@ -52,7 +51,7 @@ var settings = ConnectionURL{ Password: password, } -var host = flag.String("host", "testserver.local", "Testing server address.") +var host string // Structure for testing conversions and datatypes. type testValuesStruct struct { @@ -127,8 +126,11 @@ func init() { time.Second * time.Duration(7331), } - flag.Parse() - settings.Address = db.ParseAddress(*host) + if host = os.Getenv("TEST_HOST"); host == "" { + host = "localhost" + } + + settings.Address = db.ParseAddress(host) } // Enabling outputting some information to stdout, useful for development. @@ -155,7 +157,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: username, Password: password, } @@ -169,7 +171,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: username, Password: "fail", } @@ -181,7 +183,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong database. wrongSettings = db.Settings{ Database: "fail", - Host: *host, + Host: host, User: username, Password: password, } @@ -193,7 +195,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: "fail", Password: password, } diff --git a/mysql/_dumps/Makefile b/mysql/_dumps/Makefile index 5e9b4f17c51413eec7ced1a74cea64aa3e70e253..fb1b4e28247a353505402c796fec0018420ad638 100644 --- a/mysql/_dumps/Makefile +++ b/mysql/_dumps/Makefile @@ -1,2 +1,4 @@ +TEST_HOST ?= 127.0.0.1 + all: - cat structs.sql | mysql -uupperio -pupperio upperio_tests -htestserver.local + cat structs.sql | mysql -uupperio -pupperio upperio_tests -h$(TEST_HOST) diff --git a/mysql/collection.go b/mysql/collection.go index fb1ee860a3665618821861c3c362a3f4ed001c90..d510642b72976d0a02435c5469ec4e54e08e6caf 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -23,233 +23,57 @@ package mysql import ( "database/sql" - "fmt" - "reflect" "strings" "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source - names []string -} - -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args + *database } -func interfaceArgs(value interface{}) (args []interface{}) { - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } +var _ = db.Collection(&table{}) - return args +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} - -func (c *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: c, - where: where, - arguments: arguments, - } - - return result -} - -func (c *table) tableN(i int) string { - if len(c.names) > i { - chunks := strings.SplitN(c.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. -func (c *table) Truncate() error { - - _, err := c.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{c.tableN(0)}, +// Truncate deletes all rows from the table. +func (t *table) Truncate() error { + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. -func (c *table) Append(item interface{}) (interface{}, error) { - +// Append inserts an item (map or struct) into the collection. +func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - cols, vals, err := c.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = c.source.getPrimaryKey(c.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -257,14 +81,14 @@ func (c *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{c.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = c.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -291,10 +115,10 @@ func (c *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -317,14 +141,15 @@ func (c *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. -func (c *table) Exists() bool { - if err := c.source.tableExists(c.names...); err != nil { +// Exists returns true if the collection exists. +func (t *table) Exists() bool { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } -func (c *table) Name() string { - return strings.Join(c.names, `, `) +// Name returns the name of the table or tables that form the collection. +func (t *table) Name() string { + return strings.Join(t.Tables, `, `) } diff --git a/mysql/database.go b/mysql/database.go index 97bf5038620bd07e153b01299785bc6017a753e4..df840b2221f2bc1dbf52cf8e0fec1841d3fc8b48 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -23,95 +23,55 @@ package mysql import ( "database/sql" - "os" "strings" "time" - // Importing MySQL driver. - _ "github.com/go-sql-driver/mysql" + _ "github.com/go-sql-driver/mysql" // MySQL driver. "github.com/jmoiron/sqlx" - "upper.io/cache" "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/tx" ) -const ( - // Adapter is the public name of the adapter. - Adapter = `mysql` -) - -var template *sqlgen.Template - var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema } -type columnSchemaT struct { - Name string `db:"column_name"` -} - -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name +type tx struct { + *sqltx.Tx + *database } -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} - -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sqlx.Tx - - if sqlTx, err = s.session.Beginx(); err != nil { - return nil, err - } - - if clone, err = s.clone(); err != nil { - return nil, err - } - - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -// Stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { - s.connURL = connURL - return s.Open() +type columnSchemaT struct { + Name string `db:"column_name"` } -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { var err error // Before db.ConnectionURL we used a unified db.Settings struct. This // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { + if settings, ok := d.connURL.(db.Settings); ok { // User is providing a db.Settings struct, let's translate it into a // ConnectionURL{}. @@ -141,64 +101,71 @@ func (s *source) Open() error { conn.Address = db.HostPort(settings.Host, uint(settings.Port)) } - // Replace original s.connURL - s.connURL = conn + // Replace original d.connURL + d.connURL = conn } - if s.session, err = sqlx.Open(`mysql`, s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`mysql`, d.connURL.String()); err != nil { return err } - s.session.Mapper = sqlutil.NewMapper() + d.session.Mapper = sqlutil.NewMapper() - if err = s.populateSchema(); err != nil { + if err = d.populateSchema(); err != nil { return err } return nil } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL +func (d *database) clone() (*database, error) { + src := &database{} + src.Setup(d.connURL) - if conn, err = ParseURL(s.connURL.String()); err != nil { - return err + if err := src.Open(); err != nil { + return nil, err } - conn.Database = database + return src, nil +} - s.connURL = conn +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() +} - return s.Open() +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { var err error if len(names) == 0 { return nil, db.ErrMissingCollectionName } - if s.tx != nil { - if s.tx.done { + if d.tx != nil { + if d.tx.Done() { return nil, sql.ErrTxDone } } - col := &table{ - source: s, - names: names, - } + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper for _, name := range names { chunks := strings.SplitN(name, ` `, 2) @@ -209,11 +176,11 @@ func (s *source) Collection(names ...string) (db.Collection, error) { tableName := chunks[0] - if err := s.tableExists(tableName); err != nil { + if err := d.tableExists(tableName); err != nil { return nil, err } - if col.Columns, err = s.tableColumns(tableName); err != nil { + if col.Columns, err = d.tableColumns(tableName); err != nil { return nil, err } } @@ -221,49 +188,35 @@ func (s *source) Collection(names ...string) (db.Collection, error) { return col, nil } -// Drops the currently active database. -func (s *source) Drop() error { - - _, err := s.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlDropDatabase, - Database: sqlgen.Database{s.schema.Name}, - }) - - return err -} - -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - tablesInSchema := len(s.schema.Tables) + tablesInSchema := len(d.schema.Tables) // Is schema already populated? if tablesInSchema > 0 { // Pulling table names from schema. - return s.schema.Tables, nil + return d.schema.Tables, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Table: sqlgen.Table{ - `information_schema.tables`, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_schema`}, - `=`, - sqlPlaceholder, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Table: sqlgen.TableWithName(`information_schema.tables`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, }, - }, + ), } // Executing statement. var rows *sqlx.Rows - if rows, err = s.doQuery(stmt, s.schema.Name); err != nil { + if rows, err = d.Query(stmt, d.schema.Name); err != nil { return nil, err } @@ -280,7 +233,7 @@ func (s *source) Collections() (collections []string, err error) { } // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.AddTable(name) // Adding table to collections array. collections = append(collections, name) @@ -289,25 +242,65 @@ func (s *source) Collections() (collections []string, err error) { return collections, nil } -func (s *source) clone() (*source, error) { - src := &source{} - src.Setup(s.connURL) +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL - if err := src.Open(); err != nil { - return nil, err + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err } - return src, nil + conn.Database = database + + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) + + return err +} + +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err + } + + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err } + + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -317,25 +310,26 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - res, err = s.session.Exec(query, args...) + res, err = d.session.Exec(query, args...) } return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -345,25 +339,26 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - rows, err = s.session.Queryx(query, args...) + rows, err = d.session.Queryx(query, args...) } return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -373,62 +368,57 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - row = s.session.QueryRowx(query, args...) + row = d.session.QueryRowx(query, args...) } return row, err } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} - -func (s *source) populateSchema() (err error) { +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { var collections []string - s.schema = schema.NewDatabaseSchema() + d.schema = schema.NewDatabaseSchema() // Get database name. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {sqlgen.Raw{`DATABASE()`}}, - }, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.RawValue(`DATABASE()`), + ), } var row *sqlx.Row - if row, err = s.doQueryRow(stmt); err != nil { + if row, err = d.QueryRow(stmt); err != nil { return err } - if err = row.Scan(&s.schema.Name); err != nil { + if err = row.Scan(&d.schema.Name); err != nil { return err } // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if collections, err = d.Collections(); err != nil { return err } for i := range collections { // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { + if _, err = d.Collection(collections[i]); err != nil { return err } } @@ -436,31 +426,39 @@ func (s *source) populateSchema() (err error) { return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.tables`}, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.tables`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -474,32 +472,40 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.columns`}, - Columns: sqlgen.Columns{ - {`column_name`}, - {`data_type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.columns`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`column_name`), + sqlgen.ColumnWithName(`data_type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -509,54 +515,64 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -func (s *source) getPrimaryKey(tableName string) ([]string, error) { +func (d *database) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { return tableSchema.PrimaryKey, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{ - sqlgen.Raw{` + Type: sqlgen.Select, + Table: sqlgen.RawValue(` information_schema.table_constraints AS t JOIN information_schema.key_column_usage k USING(constraint_name, table_schema, table_name) - `}, - }, - Columns: sqlgen.Columns{ - {`k.column_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`t.constraint_type`}, `=`, sqlgen.Value{`primary key`}}, - sqlgen.ColumnValue{sqlgen.Column{`t.table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`t.table_name`}, `=`, sqlPlaceholder}, - }, - OrderBy: sqlgen.OrderBy{ - sqlgen.SortColumns{ - { - sqlgen.Column{`k.ordinal_position`}, - sqlgen.SqlSortAsc, - }, + `), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`k.column_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.constraint_type`), + Operator: `=`, + Value: sqlgen.NewValue(`primary key`), + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.table_name`), + Operator: `=`, + Value: sqlPlaceholder, }, + ), + OrderBy: &sqlgen.OrderBy{ + SortColumns: sqlgen.JoinSortColumns( + &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(`k.ordinal_position`), + Order: sqlgen.Ascendent, + }, + ), }, } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -572,40 +588,3 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { return tableSchema.PrimaryKey, nil } - -func init() { - - template = &sqlgen.Template{ - mysqlColumnSeparator, - mysqlIdentifierSeparator, - mysqlIdentifierQuote, - mysqlValueSeparator, - mysqlValueQuote, - mysqlAndKeyword, - mysqlOrKeyword, - mysqlNotKeyword, - mysqlDescKeyword, - mysqlAscKeyword, - mysqlDefaultOperator, - mysqlClauseGroup, - mysqlClauseOperator, - mysqlColumnValue, - mysqlTableAliasLayout, - mysqlColumnAliasLayout, - mysqlSortByColumnLayout, - mysqlWhereLayout, - mysqlOrderByLayout, - mysqlInsertLayout, - mysqlSelectLayout, - mysqlUpdateLayout, - mysqlDeleteLayout, - mysqlTruncateLayout, - mysqlDropDatabaseLayout, - mysqlDropTableLayout, - mysqlSelectCountLayout, - mysqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) -} diff --git a/mysql/database_test.go b/mysql/database_test.go index bda28252950b806086046028666c945384924c63..624dcc20f1b7db9c27e54763156153d1eec5bf82 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -40,9 +40,9 @@ import ( ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" ) const ( @@ -50,7 +50,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, Options: map[string]string{ @@ -179,7 +179,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: password, @@ -193,7 +193,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: "fail", @@ -217,7 +217,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: "fail", Password: password, @@ -234,7 +234,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, Host: host, @@ -1470,7 +1470,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec("TRUNCATE TABLE `artist`"); err != nil { b.Fatal(err) @@ -1524,7 +1524,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/mysql/mysql.go b/mysql/mysql.go new file mode 100644 index 0000000000000000000000000000000000000000..6cdf2d36b7d8aa56b163581a11b4c10f05d41ec8 --- /dev/null +++ b/mysql/mysql.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package mysql + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `mysql` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/mysql/result.go b/mysql/result.go deleted file mode 100644 index b4fe80c81bece49e71c06c8dea3872f8d2b0dbb9..0000000000000000000000000000000000000000 --- a/mysql/result.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package mysql - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"_t"` -} - -type result struct { - table *table - cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - if err = sqlutil.FetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return nil -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values) - if err != nil { - return err - } - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() (err error) { - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { - var count counter - - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - err = row.Scan(&count.Total) - if err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/mysql/layout.go b/mysql/template.go similarity index 63% rename from mysql/layout.go rename to mysql/template.go index c0ef34060f173ded0bbe07bc95f3c03611f6cf8f..1b2d21dab87f67e0f0cc87ed872d9ba9f37b1112 100644 --- a/mysql/layout.go +++ b/mysql/template.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,37 +22,38 @@ package mysql const ( - mysqlColumnSeparator = `.` - mysqlIdentifierSeparator = `, ` - mysqlIdentifierQuote = "`{{.Raw}}`" - mysqlValueSeparator = `, ` - mysqlValueQuote = `'{{.}}'` - mysqlAndKeyword = `AND` - mysqlOrKeyword = `OR` - mysqlNotKeyword = `NOT` - mysqlDescKeyword = `DESC` - mysqlAscKeyword = `ASC` - mysqlDefaultOperator = `=` - mysqlClauseGroup = `({{.}})` - mysqlClauseOperator = ` {{.}} ` - mysqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - mysqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - mysqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - mysqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - mysqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = "`{{.Value}}`" + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - mysqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - mysqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -79,19 +80,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - mysqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - mysqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - mysqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -106,7 +107,7 @@ const ( {{end}} ` - mysqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -114,23 +115,21 @@ const ( {{.Extra}} ` - mysqlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} ` - mysqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - mysqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - mysqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - mysqlNull = `NULL` ) diff --git a/mysql/tx.go b/mysql/tx.go deleted file mode 100644 index 844f11d019bf7365d5689135a0621802bdddabea..0000000000000000000000000000000000000000 --- a/mysql/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package mysql - -import ( - "github.com/jmoiron/sqlx" -) - -type tx struct { - *source - sqlTx *sqlx.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/postgresql/_dumps/Makefile b/postgresql/_dumps/Makefile index 675048196a990e22c46df5ac45a4eac2ea0c9818..83de0013912ec7f3ef76d69fd8bb4fd45f5b5a54 100644 --- a/postgresql/_dumps/Makefile +++ b/postgresql/_dumps/Makefile @@ -1,2 +1,4 @@ +TEST_HOST ?= 127.0.0.1 + all: - cat structs.sql | PGPASSWORD="upperio" psql -Uupperio upperio_tests -htestserver.local + cat structs.sql | PGPASSWORD="upperio" psql -Uupperio upperio_tests -h$(TEST_HOST) diff --git a/postgresql/_dumps/structs.sql b/postgresql/_dumps/structs.sql index 0b4069022cefa2a685b64745a97d2bdb97cdab06..25a210f573e6ba623a789e25db6cda8da9b7fd50 100644 --- a/postgresql/_dumps/structs.sql +++ b/postgresql/_dumps/structs.sql @@ -64,3 +64,12 @@ CREATE TABLE composite_keys ( some_val varchar(255) default '', primary key (code, user_id) ); + +DROP TABLE IF EXISTS option_types; + +CREATE TABLE option_types ( + id serial primary key, + name varchar(255) default '', + tags varchar(64)[], + settings jsonb +); diff --git a/postgresql/collection.go b/postgresql/collection.go index ea1e3a687dcd136fcd7c208712369e1bc0e6ce78..0158c17dd6f8b646cc5a94e685a46dfbb4e725b0 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -24,189 +24,34 @@ package postgresql import ( "database/sql" "fmt" - "reflect" "strings" "github.com/jmoiron/sqlx" "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source + *database primaryKey string - names []string } -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{psqlNull}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} +var _ = db.Collection(&table{}) +// Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: t, - where: where, - arguments: arguments, - } - - return result -} - -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -// Deletes all the rows within the collection. +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - _, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{t.tableN(0)}, + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { @@ -216,43 +61,24 @@ func (t *table) Truncate() error { return nil } -// Appends an item (map or struct) into the collection. +// Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - cols, vals, err := t.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - var columns sqlgen.Columns - - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - var values sqlgen.Values - var arguments []interface{} - - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } var pKey []string - if pKey, err = t.source.getPrimaryKey(t.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -260,17 +86,17 @@ func (t *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{t.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } // No primary keys defined. if len(pKey) == 0 { var res sql.Result - if res, err = t.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -285,7 +111,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { // A primary key was found. stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) - if rows, err = t.source.doQuery(stmt, arguments...); err != nil { + if rows, err = t.database.Query(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -332,14 +158,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. +// Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) + return strings.Join(t.Tables, `, `) } diff --git a/postgresql/connection.go b/postgresql/connection.go index b1d6ab909f1a42a678e13e50b8084ce53bbfeb3f..451383ce96af9a707eb03d9e354405127da4aeca 100644 --- a/postgresql/connection.go +++ b/postgresql/connection.go @@ -37,8 +37,8 @@ type scanner struct { i int } -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. +// Next returns the next rune. It returns 0, false if the end of the text has +// been reached. func (s *scanner) Next() (rune, bool) { if s.i >= len(s.s) { return 0, false @@ -48,8 +48,8 @@ func (s *scanner) Next() (rune, bool) { return r, true } -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. +// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the +// end of the text has been reached. func (s *scanner) SkipSpaces() (rune, bool) { r, ok := s.Next() for unicode.IsSpace(r) && ok { diff --git a/postgresql/database.go b/postgresql/database.go index 158f9a8d04db07ffe47e425561ab2279e20ded1c..a95c014c567687ee94704e32e9294f79f1dd93db 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -24,74 +24,58 @@ package postgresql import ( "database/sql" "fmt" - "os" "strconv" "strings" "time" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" // PostgreSQL driver. - "upper.io/cache" "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" -) - -const ( - // Adapter is the public name of the adapter. - Adapter = `postgresql` + "upper.io/db/util/sqlutil/tx" ) var ( - // Query template - template *sqlgen.Template - - // Query statement placeholder - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema } -type columnSchemaT struct { - Name string `db:"column_name"` - DataType string `db:"data_type"` +type tx struct { + *sqltx.Tx + *database } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() - } +type columnSchemaT struct { + Name string `db:"column_name"` + DataType string `db:"data_type"` } -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { var err error // Before db.ConnectionURL we used a unified db.Settings struct. This // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { + if settings, ok := d.connURL.(db.Settings); ok { - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. conn := ConnectionURL{ User: settings.User, Password: settings.Password, @@ -102,30 +86,31 @@ func (s *source) Open() error { }, } - // Replace original s.connURL - s.connURL = conn + d.connURL = conn } - if s.session, err = sqlx.Open(`postgres`, s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`postgres`, d.connURL.String()); err != nil { return err } - s.session.Mapper = sqlutil.NewMapper() + d.session.Mapper = sqlutil.NewMapper() - if err = s.populateSchema(); err != nil { + if err = d.populateSchema(); err != nil { return err } return nil } -func (s *source) Clone() (db.Database, error) { - return s.clone() +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func (s *source) clone() (*source, error) { - src := new(source) - src.Setup(s.connURL) +func (d *database) clone() (*database, error) { + src := new(database) + src.Setup(d.connURL) if err := src.Open(); err != nil { return nil, err @@ -134,38 +119,37 @@ func (s *source) clone() (*source, error) { return src, nil } -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() } return nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { var err error if len(names) == 0 { return nil, db.ErrMissingCollectionName } - if s.tx != nil { - if s.tx.done { + if d.tx != nil { + if d.tx.Done() { return nil, sql.ErrTxDone } } - col := &table{ - source: s, - names: names, - } + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper for _, name := range names { chunks := strings.SplitN(name, ` `, 2) @@ -176,11 +160,11 @@ func (s *source) Collection(names ...string) (db.Collection, error) { tableName := chunks[0] - if err := s.tableExists(tableName); err != nil { + if err := d.tableExists(tableName); err != nil { return nil, err } - if col.Columns, err = s.tableColumns(tableName); err != nil { + if col.Columns, err = d.tableColumns(tableName); err != nil { return nil, err } } @@ -188,41 +172,38 @@ func (s *source) Collection(names ...string) (db.Collection, error) { return col, nil } -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - tablesInSchema := len(s.schema.Tables) + tablesInSchema := len(d.schema.Tables) // Is schema already populated? if tablesInSchema > 0 { // Pulling table names from schema. - return s.schema.Tables, nil + return d.schema.Tables, nil } // Schema is empty. // Querying table names. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Table: sqlgen.Table{ - `information_schema.tables`, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_schema`}, - `=`, - sqlgen.Value{`public`}, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Table: sqlgen.TableWithName(`information_schema.tables`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlgen.NewValue(`public`), }, - }, + ), } // Executing statement. var rows *sqlx.Rows - if rows, err = s.doQuery(stmt); err != nil { + if rows, err = d.Query(stmt); err != nil { return nil, err } @@ -239,7 +220,7 @@ func (s *source) Collections() (collections []string, err error) { } // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.AddTable(name) // Adding table to collections array. collections = append(collections, name) @@ -248,62 +229,63 @@ func (s *source) Collections() (collections []string, err error) { return collections, nil } -// Changes the active database. -func (s *source) Use(database string) (err error) { +// Use changes the active database. +func (d *database) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } conn.Database = database - s.connURL = conn + d.connURL = conn - return s.Open() + return d.Open() } -// Drops the currently active database. -func (s *source) Drop() error { - _, err := s.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlDropDatabase, - Database: sqlgen.Database{s.schema.Name}, +// Drop removes all tables from the current database. +func (d *database) Drop() error { + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), }) return err } -// Stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { - s.connURL = connURL - return s.Open() +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name } -func (s *source) Transaction() (db.Tx, error) { +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { var err error - var clone *source + var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = s.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = s.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx + clone.tx = sqltx.New(sqlTx) - return tx, nil + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -313,30 +295,31 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - res, err = s.session.Exec(query, args...) + res, err = d.session.Exec(query, args...) } return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -346,30 +329,31 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - rows, err = s.session.Queryx(query, args...) + rows, err = d.session.Queryx(query, args...) } return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -379,60 +363,60 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, `$`+strconv.Itoa(i+1), 1) } - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - row = s.session.QueryRowx(query, args...) + row = d.session.QueryRowx(query, args...) } return row, err } -func (s *source) populateSchema() (err error) { +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { var collections []string - s.schema = schema.NewDatabaseSchema() + d.schema = schema.NewDatabaseSchema() // Get database name. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {sqlgen.Raw{`CURRENT_DATABASE()`}}, - }, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.RawValue(`CURRENT_DATABASE()`), + ), } var row *sqlx.Row - if row, err = s.doQueryRow(stmt); err != nil { + if row, err = d.QueryRow(stmt); err != nil { return err } - if err = row.Scan(&s.schema.Name); err != nil { + if err = row.Scan(&d.schema.Name); err != nil { return err } - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if collections, err = d.Collections(); err != nil { return err } for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { + if _, err = d.Collection(collections[i]); err != nil { return err } } @@ -440,37 +424,45 @@ func (s *source) populateSchema() (err error) { return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.tables`}, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_catalog`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.tables`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_catalog`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { return db.ErrCollectionDoesNotExist } defer rows.Close() - if rows.Next() == false { + if !rows.Next() { return db.ErrCollectionDoesNotExist } } @@ -478,42 +470,40 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{ - `information_schema.columns`, - }, - Columns: sqlgen.Columns{ - {`column_name`}, - {`data_type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_catalog`}, - `=`, - sqlPlaceholder, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.columns`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`column_name`), + sqlgen.ColumnWithName(`data_type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_catalog`), + Operator: `=`, + Value: sqlPlaceholder, }, - sqlgen.ColumnValue{ - sqlgen.Column{`table_name`}, - `=`, - sqlPlaceholder, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, }, - }, + ), } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -525,17 +515,17 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -func (s *source) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) +func (d *database) getPrimaryKey(tableName string) ([]string, error) { + tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { return tableSchema.PrimaryKey, nil @@ -543,35 +533,37 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { // Getting primary key. See https://github.com/upper/db/issues/24. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`}, - Columns: sqlgen.Columns{ - {`pg_attribute.attname`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`pg_class.oid`}, `=`, sqlgen.Value{sqlgen.Raw{`'"` + tableName + `"'::regclass`}}}, - sqlgen.ColumnValue{sqlgen.Column{`indrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, - sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, - sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attnum`}, `=`, sqlgen.Value{sqlgen.Raw{`any(pg_index.indkey)`}}}, - sqlgen.Raw{`indisprimary`}, - }, - OrderBy: sqlgen.OrderBy{ - sqlgen.SortColumns{ - { - sqlgen.Column{`attname`}, - sqlgen.SqlSortAsc, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`pg_index, pg_class, pg_attribute`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`pg_attribute.attname`), + ), + Where: sqlgen.WhereConditions( + sqlgen.RawValue(`pg_class.oid = '"`+tableName+`"'::regclass`), + sqlgen.RawValue(`indrelid = pg_class.oid`), + sqlgen.RawValue(`pg_attribute.attrelid = pg_class.oid`), + sqlgen.RawValue(`pg_attribute.attnum = ANY(pg_index.indkey)`), + sqlgen.RawValue(`indisprimary`), + ), + OrderBy: &sqlgen.OrderBy{ + SortColumns: sqlgen.JoinSortColumns( + &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(`attname`), + Order: sqlgen.Ascendent, }, - }, + ), }, } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt); err != nil { + if rows, err = d.Query(stmt); err != nil { return nil, err } + defer rows.Close() + tableSchema.PrimaryKey = make([]string, 0, 1) for rows.Next() { @@ -584,39 +576,3 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { return tableSchema.PrimaryKey, nil } - -func init() { - template = &sqlgen.Template{ - pgsqlColumnSeparator, - pgsqlIdentifierSeparator, - pgsqlIdentifierQuote, - pgsqlValueSeparator, - pgsqlValueQuote, - pgsqlAndKeyword, - pgsqlOrKeyword, - pgsqlNotKeyword, - pgsqlDescKeyword, - pgsqlAscKeyword, - pgsqlDefaultOperator, - pgsqlClauseGroup, - pgsqlClauseOperator, - pgsqlColumnValue, - pgsqlTableAliasLayout, - pgsqlColumnAliasLayout, - pgsqlSortByColumnLayout, - pgsqlWhereLayout, - pgsqlOrderByLayout, - pgsqlInsertLayout, - pgsqlSelectLayout, - pgsqlUpdateLayout, - pgsqlDeleteLayout, - pgsqlTruncateLayout, - pgsqlDropDatabaseLayout, - pgsqlDropTableLayout, - pgsqlSelectCountLayout, - pgsqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) -} diff --git a/postgresql/database_test.go b/postgresql/database_test.go index fd83b09a0d5db685c111d56cdde7560b9e42d4a9..97e41f4896162ccbd5d5e5c55734b1cb24e70d5a 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -40,9 +40,9 @@ import ( ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" ) const ( @@ -50,7 +50,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, Options: map[string]string{ @@ -175,13 +175,13 @@ func TestOpenFailed(t *testing.T) { } // Attempts to open an empty datasource. -func TestOpenWithWrongData(t *testing.T) { +func SkipTestOpenWithWrongData(t *testing.T) { var err error var rightSettings, wrongSettings db.Settings // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: password, @@ -195,9 +195,9 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, + Database: "fail", Host: host, - User: username, + User: "fail", Password: "fail", } @@ -219,7 +219,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: "fail", Password: password, @@ -236,7 +236,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, Host: host, @@ -496,34 +496,37 @@ func TestResultFetch(t *testing.T) { res.Close() - // Dumping into an struct with no tags. - rowStruct := struct { - ID uint64 - Name string - }{} - - res = artist.Find() - - for { - err = res.Next(&rowStruct) - - if err == db.ErrNoMoreRows { - break - } + // NOTE: tags are required.. unless a different type mapper + // is specified.. - if err == nil { - if rowStruct.ID == 0 { - t.Fatalf("Expecting a not null ID.") - } - if rowStruct.Name == "" { - t.Fatalf("Expecting a name.") - } - } else { - t.Fatal(err) - } - } + // Dumping into an struct with no tags. + // rowStruct := struct { + // ID uint64 `db:"id,omitempty"` + // Name string `db:"name"` + // }{} + + // res = artist.Find() + + // for { + // err = res.Next(&rowStruct) + + // if err == db.ErrNoMoreRows { + // break + // } + + // if err == nil { + // if rowStruct.ID == 0 { + // t.Fatalf("Expecting a not null ID.") + // } + // if rowStruct.Name == "" { + // t.Fatalf("Expecting a name.") + // } + // } else { + // t.Fatal(err) + // } + // } - res.Close() + // res.Close() // Dumping into a tagged struct. rowStruct2 := struct { @@ -574,8 +577,8 @@ func TestResultFetch(t *testing.T) { // Dumping into a slice of structs. allRowsStruct := []struct { - ID uint64 - Name string + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` }{} res = artist.Find() @@ -710,6 +713,70 @@ func TestResultFetchAll(t *testing.T) { } } +func TestInlineStructs(t *testing.T) { + var sess db.Database + var err error + + var review db.Collection + + type reviewTypeDetails struct { + Name string `db:"name"` + Comments string `db:"comments"` + Created time.Time `db:"created"` + } + + type reviewType struct { + ID int64 `db:"id,omitempty"` + PublicationID int64 `db:"publication_id"` + Details reviewTypeDetails `db:",inline"` + } + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if review, err = sess.Collection("review"); err != nil { + t.Fatal(err) + } + + if err = review.Truncate(); err != nil { + t.Fatal(err) + } + + rec := reviewType{ + PublicationID: 123, + Details: reviewTypeDetails{ + Name: "..name..", Comments: "..comments..", + }, + } + + id, err := review.Append(rec) + if err != nil { + t.Fatal(err) + } + if id.(int64) <= 0 { + t.Fatal("bad id") + } + rec.ID = id.(int64) + + var recChk reviewType + err = review.Find().One(&recChk) + + if err != nil { + t.Fatal(err) + } + + if recChk.ID != rec.ID { + t.Fatal("ID of review does not match, expecting:", rec.ID, "got:", recChk.ID) + } + if recChk.Details.Name != rec.Details.Name { + t.Fatal("Name of inline field does not match, expecting:", + rec.Details.Name, "got:", recChk.Details.Name) + } +} + // Attempts to modify previously added rows. func TestUpdate(t *testing.T) { var err error @@ -728,8 +795,8 @@ func TestUpdate(t *testing.T) { // Defining destination struct value := struct { - ID uint64 - Name string + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` }{} // Getting the first artist. @@ -760,7 +827,7 @@ func TestUpdate(t *testing.T) { // Updating set with a struct rowStruct := struct { - Name string + Name string `db:"name"` }{strings.ToLower(value.Name)} if err = res.Update(rowStruct); err != nil { @@ -1584,6 +1651,281 @@ func TestDataTypes(t *testing.T) { } } +func TestOptionTypes(t *testing.T) { + var err error + var sess db.Database + var optionTypes db.Collection + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if optionTypes, err = sess.Collection("option_types"); err != nil { + t.Fatal(err) + } + + if err = optionTypes.Truncate(); err != nil { + t.Fatal(err) + } + + // TODO: lets do some benchmarking on these auto-wrapped option types.. + + // TODO: add nullable jsonb field mapped to a []string + + // A struct with wrapped option types defined in the struct tags + // for postgres string array and jsonb types + type optionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags,stringarray"` + Settings map[string]interface{} `db:"settings,jsonb"` + } + + // Item 1 + item1 := optionType{ + Name: "Food", + Tags: []string{"toronto", "pizza"}, + Settings: map[string]interface{}{"a": 1, "b": 2}, + } + + id, err := optionTypes.Append(item1) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1Chk optionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1Chk); err != nil { + t.Fatal(err) + } + + if item1Chk.Settings["a"].(float64) != 1 { // float64 because of json.. + t.Fatalf("Expecting Settings['a'] of jsonb value to be 1") + } + + if item1Chk.Tags[0] != "toronto" { + t.Fatalf("Expecting first element of Tags stringarray to be 'toronto'") + } + + // Item 1 B + item1b := &optionType{ + Name: "Golang", + Tags: []string{"love", "it"}, + Settings: map[string]interface{}{"go": 1, "lang": 2}, + } + + id, err = optionTypes.Append(item1b) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1bChk optionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1bChk); err != nil { + t.Fatal(err) + } + + if item1bChk.Settings["go"].(float64) != 1 { // float64 because of json.. + t.Fatalf("Expecting Settings['go'] of jsonb value to be 1") + } + + if item1bChk.Tags[0] != "love" { + t.Fatalf("Expecting first element of Tags stringarray to be 'love'") + } + + // Item 1 C + item1c := &optionType{ + Name: "Sup", Tags: []string{}, Settings: map[string]interface{}{}, + } + + id, err = optionTypes.Append(item1c) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1cChk optionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1cChk); err != nil { + t.Fatal(err) + } + + if len(item1cChk.Tags) != 0 { + t.Fatalf("Expecting tags array to be empty but is %v", item1cChk.Tags) + } + + if len(item1cChk.Settings) != 0 { + t.Fatalf("Expecting Settings map to be empty") + } + + // An option type to pointer jsonb field + type optionType2 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags,stringarray"` + Settings *map[string]interface{} `db:"settings,jsonb"` + } + + item2 := optionType2{ + Name: "JS", Tags: []string{"hi", "bye"}, Settings: nil, + } + + id, err = optionTypes.Append(item2) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item2Chk optionType2 + res := optionTypes.Find(db.Cond{"id": id}) + if err := res.One(&item2Chk); err != nil { + t.Fatal(err) + } + + if item2Chk.ID != id.(int64) { + t.Fatalf("Expecting id to match") + } + + if item2Chk.Name != item2.Name { + t.Fatalf("Expecting Name to match") + } + + if item2Chk.Tags[0] != item2.Tags[0] || len(item2Chk.Tags) != len(item2.Tags) { + t.Fatalf("Expecting tags to match") + } + + // Update the value + m := map[string]interface{}{} + m["lang"] = "javascript" + m["num"] = 31337 + item2.Settings = &m + err = res.Update(item2) + if err != nil { + t.Fatal(err) + } + + if err := res.One(&item2Chk); err != nil { + t.Fatal(err) + } + + if (*item2Chk.Settings)["num"].(float64) != 31337 { // float64 because of json.. + t.Fatalf("Expecting Settings['num'] of jsonb value to be 31337") + } + + if (*item2Chk.Settings)["lang"] != "javascript" { + t.Fatalf("Expecting Settings['lang'] of jsonb value to be 'javascript'") + } + + // An option type to pointer string array field + type optionType3 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags *[]string `db:"tags,stringarray"` + Settings map[string]interface{} `db:"settings,jsonb"` + } + + item3 := optionType3{ + Name: "Julia", Tags: nil, Settings: map[string]interface{}{"girl": true, "lang": true}, + } + + id, err = optionTypes.Append(item3) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item3Chk optionType2 + if err := optionTypes.Find(db.Cond{"id": id}).One(&item3Chk); err != nil { + t.Fatal(err) + } +} + +func TestOptionTypeJsonbStruct(t *testing.T) { + var err error + var sess db.Database + var optionTypes db.Collection + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + if optionTypes, err = sess.Collection("option_types"); err != nil { + t.Fatal(err) + } + + if err = optionTypes.Truncate(); err != nil { + t.Fatal(err) + } + + // A struct with wrapped option types defined in the struct tags + // for postgres string array and jsonb types + type Settings struct { + Name string `json:"name"` + Num int64 `json:"num"` + } + + type OptionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags,stringarray"` + Settings Settings `db:"settings,jsonb"` + } + + item1 := &OptionType{ + Name: "Hi", + Tags: []string{"aah", "ok"}, + Settings: Settings{Name: "a", Num: 123}, + } + + id, err := optionTypes.Append(item1) + if err != nil { + t.Fatal(err) + } + + if pk, ok := id.(int64); !ok || pk == 0 { + t.Fatalf("Expecting an ID.") + } + + var item1Chk OptionType + if err := optionTypes.Find(db.Cond{"id": id}).One(&item1Chk); err != nil { + t.Fatal(err) + } + + if len(item1Chk.Tags) != 2 { + t.Fatalf("Expecting 2 tags") + } + + if item1Chk.Tags[0] != "aah" { + t.Fatalf("Expecting first tag to be 0") + } + + if item1Chk.Settings.Name != "a" { + t.Fatalf("Expecting Name to be 'a'") + } + + if item1Chk.Settings.Num != 123 { + t.Fatalf("Expecting Num to be 123") + } +} + // We are going to benchmark the engine, so this is no longed needed. func TestDisableDebug(t *testing.T) { os.Setenv(db.EnvEnableDebug, "") @@ -1600,7 +1942,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec(`TRUNCATE TABLE "artist"`); err != nil { b.Fatal(err) @@ -1654,7 +1996,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go new file mode 100644 index 0000000000000000000000000000000000000000..7e8363cc3b1de831ce83948f0ed1baf59fadc51b --- /dev/null +++ b/postgresql/postgresql.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package postgresql + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `postgresql` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/postgresql/result.go b/postgresql/result.go deleted file mode 100644 index de0bb1d47053d133eeae423290f27c0a4360540a..0000000000000000000000000000000000000000 --- a/postgresql/result.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package postgresql - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"_t"` -} - -type result struct { - table *table - cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - if err = sqlutil.FetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return nil -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values) - if err != nil { - return err - } - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() (err error) { - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { - var count counter - - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - err = row.Scan(&count.Total) - if err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/postgresql/layout.go b/postgresql/template.go similarity index 65% rename from postgresql/layout.go rename to postgresql/template.go index 1a1bd8d8ac6690ec932c23e8fdfbed664f33e790..aa414ecbb85e1065fa52a2145eb403ac69e6c690 100644 --- a/postgresql/layout.go +++ b/postgresql/template.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,37 +22,38 @@ package postgresql const ( - pgsqlColumnSeparator = `.` - pgsqlIdentifierSeparator = `, ` - pgsqlIdentifierQuote = `"{{.Raw}}"` - pgsqlValueSeparator = `, ` - pgsqlValueQuote = `'{{.}}'` - pgsqlAndKeyword = `AND` - pgsqlOrKeyword = `OR` - pgsqlNotKeyword = `NOT` - pgsqlDescKeyword = `DESC` - pgsqlAscKeyword = `ASC` - pgsqlDefaultOperator = `=` - pgsqlClauseGroup = `({{.}})` - pgsqlClauseOperator = ` {{.}} ` - pgsqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - pgsqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - pgsqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - pgsqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - pgsqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -79,19 +80,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - pgsqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - pgsqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - pgsqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -106,7 +107,7 @@ const ( {{end}} ` - pgsqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -114,23 +115,21 @@ const ( {{.Extra}} ` - pgsqlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} RESTART IDENTITY ` - pgsqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - pgsqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - pgsqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - psqlNull = `NULL` ) diff --git a/ql/_dumps/Makefile b/ql/_dumps/Makefile index cba56b0150415fb40b311ce885b8d59a497ff8b8..1703af5222b0476e569a542c0cc26cdefd13489d 100644 --- a/ql/_dumps/Makefile +++ b/ql/_dumps/Makefile @@ -1,3 +1,3 @@ all: rm -f test.db - cat structs.sql | ql -db test.db + cat structs.sql | $$GOPATH/bin/ql -db test.db diff --git a/ql/collection.go b/ql/collection.go index a63b7ad89ade057e6342a0e8e8c6ed591cbf00f2..53a5d4c1b7e98e0237a9a42278ab7973bda83495 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,231 +22,68 @@ package ql import ( - "fmt" + "database/sql" "reflect" "strings" "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `==` - type table struct { sqlutil.T - columnTypes map[string]reflect.Kind - source *source + *database names []string + columnTypes map[string]reflect.Kind } -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} +var _ = db.Collection(&table{}) +// Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: t, - where: where, - arguments: arguments, - } - - return result + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - - _, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{t.tableN(0)}, + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. +// Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - cols, vals, err := t.FieldValues(item) - - var columns sqlgen.Columns - var values sqlgen.Values + columnNames, columnValues, err := t.FieldValues(item) - for _, col := range cols { - columns = append(columns, sqlgen.Column{col}) + if err != nil { + return nil, err } - for i := 0; i < len(vals); i++ { - values = append(values, sqlPlaceholder) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - // Error ocurred, stop appending. if err != nil { return nil, err } - res, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{t.tableN(0)}, - Columns: columns, - Values: values, - }, vals...) + stmt := sqlgen.Statement{ + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, + } - if err != nil { + var res sql.Result + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -263,14 +100,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { return id, nil } -// Returns true if the collection exists. +// Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) + return strings.Join(t.Tables, `, `) } diff --git a/ql/database.go b/ql/database.go index 1a1c5623b1c1815b373d6cb75f32ffcc6eaa6153..c6529044b60e1aca67f8371c3ae32a918d262eb0 100644 --- a/ql/database.go +++ b/ql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -24,121 +24,250 @@ package ql import ( "database/sql" "fmt" - "os" "strings" "time" _ "github.com/cznic/ql/driver" // QL driver "github.com/jmoiron/sqlx" - "upper.io/cache" "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" -) - -const ( - // Adapter is the public name of the adapter. - Adapter = `ql` + "upper.io/db/util/sqlutil/tx" ) var ( - template *sqlgen.Template - - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema } +type tx struct { + *sqltx.Tx + *database +} + +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) + type columnSchemaT struct { Name string `db:"Name"` } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session +} + +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error + + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + Database: settings.Database, + } + + d.connURL = conn } - return false + + if d.session, err = sqlx.Open(`ql`, d.connURL.String()); err != nil { + return err + } + + d.session.Mapper = sqlutil.NewMapper() + + if err = d.populateSchema(); err != nil { + return err + } + + return nil } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() +} + +func (d *database) clone() (adapter *database, err error) { + adapter = new(database) + + if err = adapter.Setup(d.connURL); err != nil { + return nil, err } + + return adapter, nil } -func init() { - - template = &sqlgen.Template{ - qlColumnSeparator, - qlIdentifierSeparator, - qlIdentifierQuote, - qlValueSeparator, - qlValueQuote, - qlAndKeyword, - qlOrKeyword, - qlNotKeyword, - qlDescKeyword, - qlAscKeyword, - qlDefaultOperator, - qlClauseGroup, - qlClauseOperator, - qlColumnValue, - qlTableAliasLayout, - qlColumnAliasLayout, - qlSortByColumnLayout, - qlWhereLayout, - qlOrderByLayout, - qlInsertLayout, - qlSelectLayout, - qlUpdateLayout, - qlDeleteLayout, - qlTruncateLayout, - qlDropDatabaseLayout, - qlDropTableLayout, - qlSelectCountLayout, - qlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -func (s *source) populateSchema() (err error) { - var collections []string +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil +} + +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error + + if len(names) == 0 { + return nil, db.ErrMissingCollectionName + } + + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } + + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper + + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) + + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } + + tableName := chunks[0] + + if err := d.tableExists(tableName); err != nil { + return nil, err + } + + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } + } - s.schema = schema.NewDatabaseSchema() + return col, nil +} + +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { + + tablesInSchema := len(d.schema.Tables) + + // Is schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil + } + + // Schema is empty. + + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Table`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + ), + } + + // Executing statement. + var rows *sqlx.Rows + if rows, err = d.Query(stmt); err != nil { + return nil, err + } + + defer rows.Close() + + collections = []string{} + + var name string + + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err + } + + // Adding table entry to schema. + d.schema.AddTable(name) + // Adding table to collections array. + collections = append(collections, name) + } + + return collections, nil +} + +// Use changes the active database. +func (d *database) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } - s.schema.Name = conn.Database + conn.Database = database - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { - return err + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + return db.ErrUnsupported +} + +// Setup stores database settings. +func (d *database) Setup(conn db.ConnectionURL) error { + d.connURL = conn + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name +} + +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err - } + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err } - return err + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -148,26 +277,26 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { var tx *sqlx.Tx - if tx, err = s.session.Beginx(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -183,7 +312,8 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -193,26 +323,26 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { var tx *sqlx.Tx - if tx, err = s.session.Beginx(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -228,7 +358,8 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -238,26 +369,26 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { var tx *sqlx.Tx - if tx, err = s.session.Beginx(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -273,194 +404,64 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return row, err } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} - -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) clone() (adapter *source, err error) { - adapter = new(source) - - if err = adapter.Setup(s.connURL); err != nil { - return nil, err - } - - return adapter, nil -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} - -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sqlx.Tx - - if clone, err = s.clone(); err != nil { - return nil, err - } - - if sqlTx, err = s.session.Beginx(); err != nil { - return nil, err - } - - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(conn db.ConnectionURL) error { - s.connURL = conn - return s.Open() -} - -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session -} - -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { - var err error - - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - Database: settings.Database, - } - - s.connURL = conn - } - - if s.session, err = sqlx.Open(`ql`, s.connURL.String()); err != nil { - return err - } - - s.session.Mapper = sqlutil.NewMapper() - - if err = s.populateSchema(); err != nil { - return err - } +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - return nil -} + d.schema = schema.NewDatabaseSchema() -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} - -// Changes the active database. -func (s *source) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } - conn.Database = database - - s.connURL = conn - - return s.Open() -} - -// Drops the currently active database. -func (s *source) Drop() error { - return db.ErrUnsupported -} - -// Returns a list of all tables within the currently active database. -func (s *source) Collections() (collections []string, err error) { - - tablesInSchema := len(s.schema.Tables) + d.schema.Name = conn.Database - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil - } - - // Schema is empty. - - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Table`}, - Columns: sqlgen.Columns{ - {`Name`}, - }, - } - - // Executing statement. - var rows *sqlx.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { + return err } - defer rows.Close() - - collections = []string{} - - var name string - - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err } - - // Adding table entry to schema. - s.schema.AddTable(name) - - // Adding table to collections array. - collections = append(collections, name) } - return collections, nil + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Table`}, - Columns: sqlgen.Columns{ - {`Name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`Name`}, `==`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Table`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`Name`), + Operator: `==`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, names[i]); err != nil { + if rows, err = d.Query(stmt, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -474,31 +475,35 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Column`}, - Columns: sqlgen.Columns{ - {`Name`}, - {`Type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`TableName`}, `==`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Column`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + sqlgen.ColumnWithName(`Type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`TableName`), + Operator: `==`, + Value: sqlPlaceholder, + }, + ), } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, tableName); err != nil { + if rows, err = d.Query(stmt, tableName); err != nil { return nil, err } @@ -508,51 +513,11 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil -} - -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } - - col := &table{ - source: s, - names: names, - } - - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) - - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName - } - - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } - - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err - } - } - - return col, nil + return d.schema.TableInfo[tableName].Columns, nil } diff --git a/ql/database_test.go b/ql/database_test.go index dd389d34df9da070abd3dde6eee1700c7fd280e6..43957d73fc4f7e13ea2fb1424f0fa1a33ff2832d 100644 --- a/ql/database_test.go +++ b/ql/database_test.go @@ -46,7 +46,7 @@ import ( ) const ( - database = `_dumps/test.db` + databaseName = `_dumps/test.db` ) const ( @@ -54,7 +54,7 @@ const ( ) var settings = db.Settings{ - Database: database, + Database: databaseName, } // Structure for testing conversions and datatypes. @@ -155,7 +155,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, } // Opening database. @@ -942,6 +942,7 @@ func TestRawQuery(t *testing.T) { var rows *sqlx.Rows var err error var drv *sqlx.DB + var tx *sqlx.Tx type publicationType struct { ID int64 `db:"id,omitempty"` @@ -957,7 +958,11 @@ func TestRawQuery(t *testing.T) { drv = sess.Driver().(*sqlx.DB) - rows, err = drv.Queryx(` + if tx, err = drv.Beginx(); err != nil { + t.Fatal(err) + } + + if rows, err = tx.Queryx(` SELECT p.id AS id, p.title AS publication_title, @@ -967,9 +972,11 @@ func TestRawQuery(t *testing.T) { (SELECT id() AS id, title, author_id FROM publication) AS p WHERE a.id == p.author_id - `) + `); err != nil { + t.Fatal(err) + } - if err != nil { + if err = tx.Commit(); err != nil { t.Fatal(err) } diff --git a/ql/ql.go b/ql/ql.go new file mode 100644 index 0000000000000000000000000000000000000000..acffe02cff797d9a84fcb0e853d938a68dd08d51 --- /dev/null +++ b/ql/ql.go @@ -0,0 +1,72 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package ql + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `ql` + +var template *sqlutil.TemplateWithUtils + +func init() { + + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/ql/result.go b/ql/result.go deleted file mode 100644 index 60985da8abcbd586424ff99aa8544b32decfdf74..0000000000000000000000000000000000000000 --- a/ql/result.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"total"` -} - -type result struct { - table *table - cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - if err = sqlutil.FetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return nil -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values) - if err != nil { - return err - } - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() (err error) { - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { - var count counter - - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - err = row.Scan(&count.Total) - if err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/ql/layout.go b/ql/template.go similarity index 65% rename from ql/layout.go rename to ql/template.go index d60f32e2d49b1bb55ba2f4dcae6dae061e31c609..61fc47c044905a107393e34539f62e3f19e1cebc 100644 --- a/ql/layout.go +++ b/ql/template.go @@ -22,37 +22,38 @@ package ql const ( - qlColumnSeparator = `.` - qlIdentifierSeparator = `, ` - qlIdentifierQuote = `{{.Raw}}` - qlValueSeparator = `, ` - qlValueQuote = `"{{.}}"` - qlAndKeyword = `&&` - qlOrKeyword = `||` - qlNotKeyword = `!=` - qlDescKeyword = `DESC` - qlAscKeyword = `ASC` - qlDefaultOperator = `==` - qlClauseGroup = `({{.}})` - qlClauseOperator = ` {{.}} ` - qlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - qlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - qlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - qlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - qlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `{{.Value}}` + adapterValueSeparator = `, ` + adapterValueQuote = `"{{.}}"` + adapterAndKeyword = `&&` + adapterOrKeyword = `||` + adapterNotKeyword = `!=` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `==` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - qlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - qlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -77,19 +78,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - qlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - qlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - qlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT count(1) AS total FROM {{.Table}} @@ -104,7 +105,7 @@ const ( {{end}} ` - qlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -112,19 +113,19 @@ const ( {{.Extra}} ` - qlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} ` - qlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - qlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - qlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} diff --git a/ql/tx.go b/ql/tx.go deleted file mode 100644 index 6093f5aa25f5ae6e53a327dc30a06606f8a7e84e..0000000000000000000000000000000000000000 --- a/ql/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "github.com/jmoiron/sqlx" -) - -type tx struct { - *source - sqlTx *sqlx.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/sqlite/collection.go b/sqlite/collection.go index c65e64d4a1790f5f4e0337c268399d8e5c9d5f2d..0aadfcd010ce8ae64d1e4ce1bc6321b3d69bd9c4 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,8 +22,6 @@ package sqlite import ( - "fmt" - "reflect" "strings" "database/sql" @@ -31,228 +29,52 @@ import ( "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source - names []string -} - -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args + *database } -func (c *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: c, - where: where, - arguments: arguments, - } +var _ = db.Collection(&table{}) - return result -} - -func (c *table) tableN(i int) string { - if len(c.names) > i { - chunks := strings.SplitN(c.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -// Deletes all the rows within the collection. -func (c *table) Truncate() error { - - _, err := c.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{c.tableN(0)}, +// Truncate deletes all rows from the table. +func (t *table) Truncate() error { + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. -func (c *table) Append(item interface{}) (interface{}, error) { - +// Append inserts an item (map or struct) into the collection. +func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - cols, vals, err := c.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) - // Error ocurred, stop appending. if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = c.source.getPrimaryKey(c.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -260,14 +82,14 @@ func (c *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{c.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = c.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -294,10 +116,10 @@ func (c *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -320,14 +142,15 @@ func (c *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. -func (c *table) Exists() bool { - if err := c.source.tableExists(c.names...); err != nil { +// Exists returns true if the collection exists. +func (t *table) Exists() bool { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } -func (c *table) Name() string { - return strings.Join(c.names, `, `) +// Name returns the name of the table or tables that form the collection. +func (t *table) Name() string { + return strings.Join(t.Tables, `, `) } diff --git a/sqlite/database.go b/sqlite/database.go index d0998b154dd6ef506609a7c82fc057bb7433c94a..22de80bb25b0fbb068dfe720649e18137d903a95 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -24,35 +24,26 @@ package sqlite import ( "database/sql" "fmt" - "os" "strings" "time" - // Importing SQLite3 driver. "github.com/jmoiron/sqlx" - _ "github.com/mattn/go-sqlite3" - "upper.io/cache" + _ "github.com/mattn/go-sqlite3" // SQLite3 driver. "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/tx" ) -const ( - // Adapter is the public name of the adapter. - Adapter = `sqlite` -) - -var template *sqlgen.Template - var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema // columns property was introduced so we could query PRAGMA data only once // and retrieve all the column information we'd need, such as name and if it @@ -60,399 +51,392 @@ type source struct { columns map[string][]columnSchemaT } -type columnSchemaT struct { - Name string `db:"name"` - PK int `db:"pk"` +type tx struct { + *sqltx.Tx + *database } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() - } +type columnSchemaT struct { + Name string `db:"name"` + PK int `db:"pk"` } -func init() { - - template = &sqlgen.Template{ - sqlColumnSeparator, - sqlIdentifierSeparator, - sqlIdentifierQuote, - sqlValueSeparator, - sqlValueQuote, - sqlAndKeyword, - sqlOrKeyword, - sqlNotKeyword, - sqlDescKeyword, - sqlAscKeyword, - sqlDefaultOperator, - sqlClauseGroup, - sqlClauseOperator, - sqlColumnValue, - sqlTableAliasLayout, - sqlColumnAliasLayout, - sqlSortByColumnLayout, - sqlWhereLayout, - sqlOrderByLayout, - sqlInsertLayout, - sqlSelectLayout, - sqlUpdateLayout, - sqlDeleteLayout, - sqlTruncateLayout, - sqlDropDatabaseLayout, - sqlDropTableLayout, - sqlSelectCountLayout, - sqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -func (s *source) populateSchema() (err error) { - var collections []string +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error - s.schema = schema.NewDatabaseSchema() + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + Database: settings.Database, + Options: map[string]string{ + "cache": "shared", + }, + } - var conn ConnectionURL + d.connURL = conn + } - if conn, err = ParseURL(s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`sqlite3`, d.connURL.String()); err != nil { return err } - s.schema.Name = conn.Database + d.session.Mapper = sqlutil.NewMapper() - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if err = d.populateSchema(); err != nil { return err } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err - } - } + return nil +} - return err +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { - var query string - var res sql.Result - var err error - var start, end int64 +func (d *database) clone() (*database, error) { + src := &database{} + src.Setup(d.connURL) - start = time.Now().UnixNano() + if err := src.Open(); err != nil { + return nil, err + } - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + return src, nil +} - if s.session == nil { - return nil, db.ErrNotConnected +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() +} + +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() } + return nil +} - query = stmt.Compile(template) +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) - } else { - res, err = s.session.Exec(query, args...) + if len(names) == 0 { + return nil, db.ErrMissingCollectionName } - return res, err -} + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { - var rows *sqlx.Rows - var query string - var err error - var start, end int64 + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper - start = time.Now().UnixNano() + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } - if s.session == nil { - return nil, db.ErrNotConnected - } + tableName := chunks[0] - query = stmt.Compile(template) + if err := d.tableExists(tableName); err != nil { + return nil, err + } - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) - } else { - rows, err = s.session.Queryx(query, args...) + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } } - return rows, err + return col, nil } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { - var query string - var row *sqlx.Row - var err error - var start, end int64 +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - start = time.Now().UnixNano() + tablesInSchema := len(d.schema.Tables) - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() - - if s.session == nil { - return nil, db.ErrNotConnected + // Id.schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil } - query = stmt.Compile(template) + // Schema is empty. - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) - } else { - row = s.session.QueryRowx(query, args...) + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`tbl_name`), + ), + Table: sqlgen.TableWithName(`sqlite_master`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`type`), + Operator: `=`, + Value: sqlgen.NewValue(`table`), + }, + ), } - return row, err -} - -func (s *source) doRawQuery(query string, args ...interface{}) (*sqlx.Rows, error) { + // Executing statement. var rows *sqlx.Rows - var err error - var start, end int64 + if rows, err = d.Query(stmt); err != nil { + return nil, err + } - start = time.Now().UnixNano() + defer rows.Close() - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + collections = []string{} - if s.session == nil { - return nil, db.ErrNotConnected - } + var name string - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) - } else { - rows, err = s.session.Queryx(query, args...) + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err + } + + // Adding table entry to schema. + d.schema.AddTable(name) + + // Adding table to collections array. + collections = append(collections, name) } - return rows, err + return collections, nil } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL + + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err + } + + conn.Database = database + + d.connURL = conn -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() + return d.Open() } -func (s *source) clone() (*source, error) { - src := &source{} - src.Setup(s.connURL) +// Drop removes all tables from the current database. +func (d *database) Drop() error { - if err := src.Open(); err != nil { - return nil, err - } + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) - return src, nil + return err +} + +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() } -func (s *source) Clone() (db.Database, error) { - return s.clone() +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name } -func (s *source) Transaction() (db.Tx, error) { +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { var err error - var clone *source + var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = s.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = s.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(conn db.ConnectionURL) error { - s.connURL = conn - return s.Open() -} + clone.tx = sqltx.New(sqlTx) -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session + return tx{Tx: clone.tx, database: clone}, nil } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { + var query string + var res sql.Result var err error + var start, end int64 - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - Database: settings.Database, - Options: map[string]string{ - "cache": "shared", - }, - } + start = time.Now().UnixNano() - s.connURL = conn - } + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - if s.session, err = sqlx.Open(`sqlite3`, s.connURL.String()); err != nil { - return err + if d.session == nil { + return nil, db.ErrNotConnected } - s.session.Mapper = sqlutil.NewMapper() + query = stmt.Compile(template.Template) - if err = s.populateSchema(); err != nil { - return err + if d.tx != nil { + res, err = d.tx.Exec(query, args...) + } else { + res, err = d.session.Exec(query, args...) } - return nil + return res, err } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows + var query string + var err error + var start, end int64 -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL + start = time.Now().UnixNano() - if conn, err = ParseURL(s.connURL.String()); err != nil { - return err + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() + + if d.session == nil { + return nil, db.ErrNotConnected } - conn.Database = database + query = stmt.Compile(template.Template) - s.connURL = conn + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) + } else { + rows, err = d.session.Queryx(query, args...) + } - return s.Open() + return rows, err } -// Drops the currently active database. -func (s *source) Drop() error { - return db.ErrUnsupported -} +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { + var query string + var row *sqlx.Row + var err error + var start, end int64 -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { + start = time.Now().UnixNano() - tablesInSchema := len(s.schema.Tables) + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil + if d.session == nil { + return nil, db.ErrNotConnected } - // Schema is empty. + query = stmt.Compile(template.Template) - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`tbl_name`}, - }, - Table: sqlgen.Table{`sqlite_master`}, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`type`}, - `=`, - sqlgen.Value{`table`}, - }, - }, + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) + } else { + row = d.session.QueryRowx(query, args...) } - // Executing statement. - var rows *sqlx.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err - } + return row, err +} - defer rows.Close() +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - collections = []string{} + d.schema = schema.NewDatabaseSchema() - var name string + var conn ConnectionURL - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err - } + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err + } - // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.Name = conn.Database - // Adding table to collections array. - collections = append(collections, name) + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { + return err } - return collections, nil + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err + } + } + + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`sqlite_master`}, - Columns: sqlgen.Columns{ - {`tbl_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`type`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`tbl_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`sqlite_master`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`tbl_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`type`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`tbl_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, `table`, names[i]); err != nil { + if rows, err = d.Query(stmt, `table`, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -466,10 +450,10 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil @@ -477,10 +461,10 @@ func (s *source) tableColumns(tableName string) ([]string, error) { q := fmt.Sprintf(`PRAGMA TABLE_INFO('%s')`, tableName) - rows, err := s.doRawQuery(q) + rows, err := d.doRawQuery(q) - if s.columns == nil { - s.columns = make(map[string][]columnSchemaT) + if d.columns == nil { + d.columns = make(map[string][]columnSchemaT) } columns := []columnSchemaT{} @@ -489,81 +473,64 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.columns[tableName] = columns + d.columns[tableName] = columns - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(s.columns)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(d.columns)) - for i := range s.columns[tableName] { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, s.columns[tableName][i].Name) + for i := range d.columns[tableName] { + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, d.columns[tableName][i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } +func (d *database) getPrimaryKey(tableName string) ([]string, error) { + tableSchema := d.schema.Table(tableName) - col := &table{ - source: s, - names: names, - } + d.tableColumns(tableName) - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) + maxValue := -1 - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName + for i := range d.columns[tableName] { + if d.columns[tableName][i].PK > 0 && d.columns[tableName][i].PK > maxValue { + maxValue = d.columns[tableName][i].PK } + } - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } + if maxValue > 0 { + tableSchema.PrimaryKey = make([]string, maxValue) - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err + for i := range d.columns[tableName] { + if d.columns[tableName][i].PK > 0 { + tableSchema.PrimaryKey[d.columns[tableName][i].PK-1] = d.columns[tableName][i].Name + } } } - return col, nil + return tableSchema.PrimaryKey, nil } -// getPrimaryKey returns the names of the columns that define the primary key -// of the table. -func (s *source) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) +func (d *database) doRawQuery(query string, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows + var err error + var start, end int64 - s.tableColumns(tableName) + start = time.Now().UnixNano() - maxValue := -1 + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - for i := range s.columns[tableName] { - if s.columns[tableName][i].PK > 0 && s.columns[tableName][i].PK > maxValue { - maxValue = s.columns[tableName][i].PK - } + if d.session == nil { + return nil, db.ErrNotConnected } - if maxValue > 0 { - tableSchema.PrimaryKey = make([]string, maxValue) - - for i := range s.columns[tableName] { - if s.columns[tableName][i].PK > 0 { - tableSchema.PrimaryKey[s.columns[tableName][i].PK-1] = s.columns[tableName][i].Name - } - } + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) + } else { + rows, err = d.session.Queryx(query, args...) } - return tableSchema.PrimaryKey, nil + return rows, err } diff --git a/sqlite/database_test.go b/sqlite/database_test.go index 4480e392fb2145268def9cd0fdf15f5e245b573c..da54e350e4550ed95064f733db6e3c27a28ac2d6 100644 --- a/sqlite/database_test.go +++ b/sqlite/database_test.go @@ -47,7 +47,7 @@ import ( ) const ( - database = `_dumps/gotest.sqlite3.db` + databaseName = `_dumps/gotest.sqlite3.db` ) const ( @@ -55,7 +55,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, } // Structure for testing conversions and datatypes. @@ -167,7 +167,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, } // Opening database. @@ -1378,10 +1378,24 @@ func TestDataTypes(t *testing.T) { loc, _ := time.LoadLocation(testTimeZone) item.Date = item.Date.In(loc) + // TODO: Try to guess this conversion. + if item.DateP.Location() != testValues.DateP.Location() { + v := item.DateP.In(testValues.DateP.Location()) + item.DateP = &v + } + // The original value and the test subject must match. - if reflect.DeepEqual(item, testValues) == false { - fmt.Printf("item1: %v\n", item) - fmt.Printf("test2: %v\n", testValues) + if !reflect.DeepEqual(item, testValues) { + fmt.Printf("item1: %#v\n", item) + fmt.Printf("test2: %#v\n", testValues) + fmt.Printf("item1: %#v\n", item.Date.String()) + fmt.Printf("test2: %#v\n", testValues.Date.String()) + fmt.Printf("item1: %v\n", item.Date.Location().String()) + fmt.Printf("test2: %v\n", testValues.Date.Location().String()) + fmt.Printf("item1: %#v\n", item.DateP) + fmt.Printf("test2: %#v\n", testValues.DateP) + fmt.Printf("item1: %v\n", item.DateP.Location().String()) + fmt.Printf("test2: %v\n", testValues.DateP.Location().String()) t.Fatalf("Struct is different.") } } @@ -1402,7 +1416,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec(`DELETE FROM "artist"`); err != nil { b.Fatal(err) @@ -1456,7 +1470,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go new file mode 100644 index 0000000000000000000000000000000000000000..ea66cafdedeb1a49e361414cda975acdfb47c827 --- /dev/null +++ b/sqlite/sqlite.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sqlite + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `sqlite` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/sqlite/layout.go b/sqlite/template.go similarity index 65% rename from sqlite/layout.go rename to sqlite/template.go index 9d937696e41b9247a817b61f8507774850b62424..abea16614305df659ce23b2b33a84df02a01da90 100644 --- a/sqlite/layout.go +++ b/sqlite/template.go @@ -22,37 +22,38 @@ package sqlite const ( - sqlColumnSeparator = `.` - sqlIdentifierSeparator = `, ` - sqlIdentifierQuote = `"{{.Raw}}"` - sqlValueSeparator = `, ` - sqlValueQuote = `'{{.}}'` - sqlAndKeyword = `AND` - sqlOrKeyword = `OR` - sqlNotKeyword = `NOT` - sqlDescKeyword = `DESC` - sqlAscKeyword = `ASC` - sqlDefaultOperator = `=` - sqlClauseGroup = `({{.}})` - sqlClauseOperator = ` {{.}} ` - sqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - sqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - sqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - sqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - sqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - sqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - sqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -75,24 +76,24 @@ const ( {{if .Offset}} {{if not .Limit}} - LIMIT -1 + LIMIT -1 {{end}} OFFSET {{.Offset}} {{end}} ` - sqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - sqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - sqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -104,13 +105,13 @@ const ( {{if .Offset}} {{if not .Limit}} - LIMIT -1 + LIMIT -1 {{end}} OFFSET {{.Offset}} {{end}} ` - sqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -118,23 +119,21 @@ const ( {{.Extra}} ` - sqlTruncateLayout = ` + adapterTruncateLayout = ` DELETE FROM {{.Table}} ` - sqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - sqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - sqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - sqlNull = `NULL` ) diff --git a/sqlite/tx.go b/sqlite/tx.go deleted file mode 100644 index dbfc7698441d78012ba56494a4054149da254179..0000000000000000000000000000000000000000 --- a/sqlite/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package sqlite - -import ( - "github.com/jmoiron/sqlx" -) - -type tx struct { - *source - sqlTx *sqlx.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/util/main.go b/util/main.go deleted file mode 100644 index 0447145c6b20b0cd73869fae8727b32ab35c1d58..0000000000000000000000000000000000000000 --- a/util/main.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package util - -import ( - "regexp" - "strings" -) - -var reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`) - -type tagOptions map[string]bool - -func parseTagOptions(s string) tagOptions { - opts := make(tagOptions) - chunks := strings.Split(s, `,`) - for _, chunk := range chunks { - opts[strings.TrimSpace(chunk)] = true - } - return opts -} - -// ParseTag splits a struct tag into comma separated chunks. The first chunk is -// returned as a string value, remaining chunks are considered enabled options. -func ParseTag(tag string) (string, tagOptions) { - // Based on http://golang.org/src/pkg/encoding/json/tags.go - if i := strings.Index(tag, `,`); i != -1 { - return tag[:i], parseTagOptions(tag[i+1:]) - } - return tag, parseTagOptions(``) -} - -// NormalizeColumn prepares a column for comparison against another column. -func NormalizeColumn(s string) string { - return strings.ToLower(reColumnCompareExclude.ReplaceAllString(s, "")) -} diff --git a/util/schema/main.go b/util/schema/schema.go similarity index 100% rename from util/schema/main.go rename to util/schema/schema.go diff --git a/util/sqlgen/benchmark_test.go b/util/sqlgen/benchmark_test.go deleted file mode 100644 index 923950352ff8da46fdd6d2026175fcfb36250149..0000000000000000000000000000000000000000 --- a/util/sqlgen/benchmark_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package sqlgen - -import ( - "fmt" - "math/rand" - "testing" -) - -func BenchmarkColumn(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Column{"a"} - } -} - -func BenchmarkCompileColumn(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Column{Value: "a"}.Compile(defaultTemplate) - } -} - -func BenchmarkColumns(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Columns{{"a"}, {"b"}, {"c"}} - } -} - -func BenchmarkCompileColumns(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Columns{{"a"}, {"b"}, {"c"}}.Compile(defaultTemplate) - } -} - -func BenchmarkValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Value{"a"} - } -} - -func BenchmarkCompileValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Value{"a"}.Compile(defaultTemplate) - } -} - -func BenchmarkValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Values{{"a"}, {"b"}, {"c"}, {1}, {2}, {3}} - } -} - -func BenchmarkCompileValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Values{{"a"}, {"b"}, {"c"}, {1}, {2}, {3}}.Compile(defaultTemplate) - } -} - -func BenchmarkDatabase(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Database{"TestDatabase"} - } -} - -func BenchmarkCompileDatabase(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Database{"TestDatabase"}.Compile(defaultTemplate) - } -} - -func BenchmarkValueRaw(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Value{Raw{"a"}} - } -} - -func BenchmarkColumnValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}} - } -} - -func BenchmarkCompileColumnValue(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}}.Compile(defaultTemplate) - } -} - -func BenchmarkColumnValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValues{{Column{"a"}, "=", Value{Raw{"7"}}}} - } -} - -func BenchmarkCompileColumnValues(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = ColumnValues{{Column{"a"}, "=", Value{Raw{"7"}}}}.Compile(defaultTemplate) - } -} - -func BenchmarkOrderBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = OrderBy{ - SortColumns: SortColumns{ - SortColumn{Column: Column{"foo"}}, - }, - } - } -} - -func BenchmarkCompileOrderBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = OrderBy{ - SortColumns: SortColumns{ - SortColumn{Column: Column{"foo"}}, - }, - }.Compile(defaultTemplate) - } -} - -func BenchmarkGroupBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = GroupBy{ - Column{"foo"}, - } - } -} - -func BenchmarkCompileGroupBy(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = GroupBy{ - Column{"foo"}, - }.Compile(defaultTemplate) - } -} - -func BenchmarkWhere(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - } - } -} - -func BenchmarkCompileWhere(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }.Compile(defaultTemplate) - } -} - -func BenchmarkTable(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = Table{"foo"} - } -} - -func BenchmarkCompileTable(b *testing.B) { - var t string - for i := 0; i < b.N; i++ { - t = Table{"foo"}.Compile(defaultTemplate) - if t != `"foo"` { - b.Fatal("Caching failed.") - } - } -} - -func BenchmarkCompileRandomTable(b *testing.B) { - var t string - var m, n int - var s, e string - - for i := 0; i < b.N; i++ { - m, n = rand.Int(), rand.Int() - s = fmt.Sprintf(`%s as %s`, m, n) - e = fmt.Sprintf(`"%s" AS "%s"`, m, n) - - t = Table{s}.Compile(defaultTemplate) - if t != e { - b.Fatal() - } - } -} - -func BenchmarkCompileSelect(b *testing.B) { - var stmt Statement - - for i := 0; i < b.N; i++ { - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}}, - }, - } - _ = stmt.Compile(defaultTemplate) - } -} diff --git a/util/sqlgen/column.go b/util/sqlgen/column.go index 6a69c98de0ec65b3126cbba9d49754f78c5f573d..4eb395e173a4d2634bbe0c2519508c07f891bfed 100644 --- a/util/sqlgen/column.go +++ b/util/sqlgen/column.go @@ -5,41 +5,58 @@ import ( "strings" ) -type column_t struct { +type columnT struct { Name string Alias string } +// Column represents a SQL column. type Column struct { - Value interface{} + Name interface{} + hash string } -func (self Column) Hash() string { - switch t := self.Value.(type) { - case cc: - return `Column(` + t.Hash() + `)` - case string: - return `Column(` + t + `)` +// ColumnWithName creates and returns a Column with the given name. +func ColumnWithName(name string) *Column { + return &Column{Name: name} +} + +// Hash returns a unique identifier. +func (c *Column) Hash() string { + if c.hash == "" { + var s string + + switch t := c.Name.(type) { + case Fragment: + s = t.Hash() + case fmt.Stringer: + s = t.String() + case string: + s = t + default: + s = fmt.Sprintf("%v", c.Name) + } + + c.hash = fmt.Sprintf(`Column{Name:%q}`, s) } - return fmt.Sprintf(`Column(%v)`, self.Value) + + return c.hash } -func (self Column) Compile(layout *Template) (compiled string) { +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *Column) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - switch value := self.Value.(type) { + switch value := c.Name.(type) { case string: - // input := strings.TrimSpace(value) input := trimString(value) - //chunks := reAliasSeparator.Split(input, 2) chunks := separateByAS(input) if len(chunks) == 1 { - //chunks = reSpaceSeparator.Split(input, 2) chunks = separateBySpace(input) } @@ -48,9 +65,8 @@ func (self Column) Compile(layout *Template) (compiled string) { nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2) for i := range nameChunks { - // nameChunks[i] = strings.TrimSpace(nameChunks[i]) nameChunks[i] = trimString(nameChunks[i]) - nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{nameChunks[i]}) + nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) } name = strings.Join(nameChunks, layout.ColumnSeparator) @@ -58,19 +74,18 @@ func (self Column) Compile(layout *Template) (compiled string) { var alias string if len(chunks) > 1 { - // alias = strings.TrimSpace(chunks[1]) alias = trimString(chunks[1]) - alias = mustParse(layout.IdentifierQuote, Raw{alias}) + alias = mustParse(layout.IdentifierQuote, Raw{Value: alias}) } - compiled = mustParse(layout.ColumnAliasLayout, column_t{name, alias}) + compiled = mustParse(layout.ColumnAliasLayout, columnT{name, alias}) case Raw: compiled = value.String() default: - compiled = fmt.Sprintf("%v", self.Value) + compiled = fmt.Sprintf("%v", c.Name) } - layout.Write(self, compiled) + layout.Write(c, compiled) return } diff --git a/util/sqlgen/column_test.go b/util/sqlgen/column_test.go index 62f3929def3ffcc912e5312bab6affcc4ad5d1c3..e5fdb7d825644a8e76b249da7e65fe20d225b336 100644 --- a/util/sqlgen/column_test.go +++ b/util/sqlgen/column_test.go @@ -1,13 +1,27 @@ package sqlgen import ( + "fmt" "testing" ) +func TestColumnHash(t *testing.T) { + var s, e string + + column := Column{Name: "role.name"} + + s = column.Hash() + e = fmt.Sprintf(`Column{Name:"%s"}`, column.Name) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestColumnString(t *testing.T) { var s, e string - column := Column{"role.name"} + column := Column{Name: "role.name"} s = column.Compile(defaultTemplate) e = `"role"."name"` @@ -20,7 +34,7 @@ func TestColumnString(t *testing.T) { func TestColumnAs(t *testing.T) { var s, e string - column := Column{"role.name as foo"} + column := Column{Name: "role.name as foo"} s = column.Compile(defaultTemplate) e = `"role"."name" AS "foo"` @@ -33,7 +47,7 @@ func TestColumnAs(t *testing.T) { func TestColumnImplicitAs(t *testing.T) { var s, e string - column := Column{"role.name foo"} + column := Column{Name: "role.name foo"} s = column.Compile(defaultTemplate) e = `"role"."name" AS "foo"` @@ -46,7 +60,7 @@ func TestColumnImplicitAs(t *testing.T) { func TestColumnRaw(t *testing.T) { var s, e string - column := Column{Raw{"role.name As foo"}} + column := Column{Name: Raw{Value: "role.name As foo"}} s = column.Compile(defaultTemplate) e = `role.name As foo` @@ -55,3 +69,51 @@ func TestColumnRaw(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkColumnWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ColumnWithName("a") + } +} + +func BenchmarkColumnHash(b *testing.B) { + c := Column{Name: "name"} + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnCompile(b *testing.B) { + c := Column{Name: "name"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Column{Name: "name"} + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithDotCompile(b *testing.B) { + c := Column{Name: "role.name"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name foo"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name AS foo"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/column_value.go b/util/sqlgen/column_value.go index fbe57987c43aec83b0ed24b42cb15156dd82c3f9..485a46f35d09dfd3efd7132d2c03cf091a421d88 100644 --- a/util/sqlgen/column_value.go +++ b/util/sqlgen/column_value.go @@ -1,71 +1,93 @@ package sqlgen import ( + "fmt" "strings" ) +// ColumnValue represents a bundle between a column and a corresponding value. type ColumnValue struct { - Column + Column Fragment Operator string - Value + Value Fragment + hash string } -type columnValue_s struct { +type columnValueT struct { Column string Operator string Value string } -func (self ColumnValue) Hash() string { - return `ColumnValue(` + self.Column.Hash() + `;` + self.Operator + `;` + self.Value.Hash() + `)` +// Hash returns a unique identifier. +func (c *ColumnValue) Hash() string { + if c.hash == "" { + c.hash = fmt.Sprintf(`ColumnValue{Name:%q, Operator:%q, Value:%q}`, c.Column.Hash(), c.Operator, c.Value.Hash()) + } + return c.hash } -func (self ColumnValue) Compile(layout *Template) (compiled string) { +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *ColumnValue) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - data := columnValue_s{ - self.Column.Compile(layout), - self.Operator, - self.Value.Compile(layout), + data := columnValueT{ + c.Column.Compile(layout), + c.Operator, + c.Value.Compile(layout), } compiled = mustParse(layout.ColumnValue, data) - layout.Write(self, compiled) + layout.Write(c, compiled) return } -type ColumnValues []ColumnValue +// ColumnValues represents an array of ColumnValue +type ColumnValues struct { + ColumnValues []Fragment + hash string +} + +// JoinColumnValues returns an array of ColumnValue +func JoinColumnValues(values ...Fragment) *ColumnValues { + return &ColumnValues{ColumnValues: values} +} -func (self ColumnValues) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +// Hash returns a unique identifier. +func (c *ColumnValues) Hash() string { + if c.hash == "" { + s := make([]string, len(c.ColumnValues)) + for i := range c.ColumnValues { + s[i] = c.ColumnValues[i].Hash() + } + c.hash = fmt.Sprintf("ColumnValues{ColumnValues:{%s}}", strings.Join(s, ", ")) } - return `ColumnValues(` + strings.Join(hash, `,`) + `)` + return c.hash } -func (self ColumnValues) Compile(layout *Template) (compiled string) { +// Compile transforms the ColumnValues into its SQL representation. +func (c *ColumnValues) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - l := len(self) + l := len(c.ColumnValues) out := make([]string, l) - for i := 0; i < l; i++ { - out[i] = self[i].Compile(layout) + for i := range c.ColumnValues { + out[i] = c.ColumnValues[i].Compile(layout) } compiled = strings.Join(out, layout.IdentifierSeparator) - layout.Write(self, compiled) + layout.Write(c, compiled) return } diff --git a/util/sqlgen/column_value_test.go b/util/sqlgen/column_value_test.go index b535ad1c245c928f97a9060746d9199b60d96253..9a954697805520941033a3366af7a605b979bb7a 100644 --- a/util/sqlgen/column_value_test.go +++ b/util/sqlgen/column_value_test.go @@ -1,14 +1,45 @@ package sqlgen import ( + "fmt" "testing" ) +func TestColumnValueHash(t *testing.T) { + var s, e string + + c := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + + s = c.Hash() + e = fmt.Sprintf(`ColumnValue{Name:%q, Operator:%q, Value:%q}`, c.Column.Hash(), c.Operator, c.Value.Hash()) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestColumnValuesHash(t *testing.T) { + var s, e string + + c := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)}, + &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(2)}, + ) + + s = c.Hash() + + e = fmt.Sprintf(`ColumnValues{ColumnValues:{%s, %s}}`, c.ColumnValues[0].Hash(), c.ColumnValues[1].Hash()) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestColumnValue(t *testing.T) { var s, e string - var cv ColumnValue + var cv *ColumnValue - cv = ColumnValue{Column{"id"}, "=", Value{1}} + cv = &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} s = cv.Compile(defaultTemplate) e = `"id" = '1'` @@ -17,7 +48,7 @@ func TestColumnValue(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } - cv = ColumnValue{Column{"date"}, "=", Value{Raw{"NOW()"}}} + cv = &ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: NewValue(RawValue("NOW()"))} s = cv.Compile(defaultTemplate) e = `"date" = NOW()` @@ -29,15 +60,14 @@ func TestColumnValue(t *testing.T) { func TestColumnValues(t *testing.T) { var s, e string - var cvs ColumnValues - cvs = ColumnValues{ - {Column{"id"}, ">", Value{8}}, - {Column{"other.id"}, "<", Value{Raw{"100"}}}, - {Column{"name"}, "=", Value{"Haruki Murakami"}}, - {Column{"created"}, ">=", Value{Raw{"NOW()"}}}, - {Column{"modified"}, "<=", Value{Raw{"NOW()"}}}, - } + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) s = cvs.Compile(defaultTemplate) e = `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()` @@ -45,5 +75,82 @@ func TestColumnValues(t *testing.T) { if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) } +} +func BenchmarkNewColumnValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = &ColumnValue{Column: ColumnWithName("a"), Operator: "=", Value: NewValue(Raw{Value: "7"})} + } +} + +func BenchmarkColumnValueHash(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + for i := 0; i < b.N; i++ { + cv.Hash() + } +} + +func BenchmarkColumnValueCompile(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + for i := 0; i < b.N; i++ { + cv.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + cv.Compile(defaultTemplate) + } +} + +func BenchmarkJoinColumnValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + } +} + +func BenchmarkColumnValuesHash(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + for i := 0; i < b.N; i++ { + cvs.Hash() + } +} + +func BenchmarkColumnValuesCompile(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + for i := 0; i < b.N; i++ { + cvs.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValuesCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + cvs.Compile(defaultTemplate) + } } diff --git a/util/sqlgen/columns.go b/util/sqlgen/columns.go index dfe41f2e38c5133550751bd4d1a3501081e1537c..507cb215f55491a1df1e3c1aa84caea16305e7e0 100644 --- a/util/sqlgen/columns.go +++ b/util/sqlgen/columns.go @@ -1,38 +1,53 @@ package sqlgen import ( + "fmt" "strings" ) -type Columns []Column +// Columns represents an array of Column. +type Columns struct { + Columns []Fragment + hash string +} -func (self Columns) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +// Hash returns a unique identifier. +func (c *Columns) Hash() string { + if c.hash == "" { + s := make([]string, len(c.Columns)) + for i := range c.Columns { + s[i] = c.Columns[i].Hash() + } + c.hash = fmt.Sprintf("Columns{Columns:{%s}}", strings.Join(s, ", ")) } - return `Columns(` + strings.Join(hash, `,`) + `)` + return c.hash +} + +// JoinColumns creates and returns an array of Column. +func JoinColumns(columns ...Fragment) *Columns { + return &Columns{Columns: columns} } -func (self Columns) Compile(layout *Template) (compiled string) { +// Compile transforms the Columns into an equivalent SQL representation. +func (c *Columns) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(c); ok { + return z } - l := len(self) + l := len(c.Columns) if l > 0 { out := make([]string, l) for i := 0; i < l; i++ { - out[i] = self[i].Compile(layout) + out[i] = c.Columns[i].Compile(layout) } compiled = strings.Join(out, layout.IdentifierSeparator) } - layout.Write(self, compiled) + layout.Write(c, compiled) return } diff --git a/util/sqlgen/columns_test.go b/util/sqlgen/columns_test.go index 668c0c420c4b72b8c46ab0919f4500ac299dde24..a4f439799f33c766a63ce971364e9d65e8cf1339 100644 --- a/util/sqlgen/columns_test.go +++ b/util/sqlgen/columns_test.go @@ -7,13 +7,13 @@ import ( func TestColumns(t *testing.T) { var s, e string - columns := Columns{ - {"id"}, - {"customer"}, - {"service_id"}, - {"role.name"}, - {"role.id"}, - } + columns := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) s = columns.Compile(defaultTemplate) e = `"id", "customer", "service_id", "role"."name", "role"."id"` @@ -21,5 +21,53 @@ func TestColumns(t *testing.T) { if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) } +} + +func BenchmarkJoinColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkColumnsHash(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnsCompile(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} +func BenchmarkColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + c.Compile(defaultTemplate) + } } diff --git a/util/sqlgen/database.go b/util/sqlgen/database.go index 6c5f731477d0053b264ff9640f03a2288c1e8dbd..df7001dd80c698529ca93371df4186f544ab07c4 100644 --- a/util/sqlgen/database.go +++ b/util/sqlgen/database.go @@ -4,22 +4,34 @@ import ( "fmt" ) +// Database represents a SQL database. type Database struct { - Value string + Name string + hash string } -func (self Database) Hash() string { - return `Database(` + self.Value + `)` +// DatabaseWithName returns a Database with the given name. +func DatabaseWithName(name string) *Database { + return &Database{Name: name} } -func (self Database) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { +// Hash returns a unique identifier. +func (d *Database) Hash() string { + if d.hash == "" { + d.hash = fmt.Sprintf(`Database{Name:%q}`, d.Name) + } + return d.hash +} + +// Compile transforms the Database into an equivalent SQL representation. +func (d *Database) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(d); ok { return c } - compiled = mustParse(layout.IdentifierQuote, Raw{fmt.Sprintf(`%v`, self.Value)}) + compiled = mustParse(layout.IdentifierQuote, Raw{Value: d.Name}) - layout.Write(self, compiled) + layout.Write(d, compiled) return } diff --git a/util/sqlgen/database_test.go b/util/sqlgen/database_test.go new file mode 100644 index 0000000000000000000000000000000000000000..33b1ad8212bcddb7af7618047a3972a27178cada --- /dev/null +++ b/util/sqlgen/database_test.go @@ -0,0 +1,53 @@ +package sqlgen + +import ( + "fmt" + "testing" +) + +func TestDatabaseHash(t *testing.T) { + var s, e string + + column := Database{Name: "users"} + + s = column.Hash() + e = fmt.Sprintf(`Database{Name:"%s"}`, column.Name) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDatabaseCompile(t *testing.T) { + var s, e string + + column := Database{Name: "name"} + + s = column.Compile(defaultTemplate) + e = `"name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkDatabaseHash(b *testing.B) { + c := Database{Name: "name"} + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkDatabaseCompile(b *testing.B) { + c := Database{Name: "name"} + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkDatabaseCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Database{Name: "name"} + c.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/default.go b/util/sqlgen/default.go index 4b2a3f20922f845b401e3af2416e891b2da8a57b..750dc9b3d4ac77f68e2da1bb1a5477875123af06 100644 --- a/util/sqlgen/default.go +++ b/util/sqlgen/default.go @@ -7,7 +7,7 @@ import ( const ( defaultColumnSeparator = `.` defaultIdentifierSeparator = `, ` - defaultIdentifierQuote = `"{{.Raw}}"` + defaultIdentifierQuote = `"{{.Value}}"` defaultValueSeparator = `, ` defaultValueQuote = `'{{.}}'` defaultAndKeyword = `AND` @@ -21,7 +21,7 @@ const ( defaultColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` defaultTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` defaultColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - defaultSortByColumnLayout = `{{.Column}} {{.Sort}}` + defaultSortByColumnLayout = `{{.Column}} {{.Order}}` defaultOrderByLayout = ` {{if .SortColumns}} @@ -72,7 +72,7 @@ const ( {{ .Where }} ` - defaultSelectCountLayout = ` + defaultCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -143,7 +143,7 @@ var defaultTemplate = &Template{ TruncateLayout: defaultTruncateLayout, DropDatabaseLayout: defaultDropDatabaseLayout, DropTableLayout: defaultDropTableLayout, - SelectCountLayout: defaultSelectCountLayout, + CountLayout: defaultCountLayout, GroupByLayout: defaultGroupByLayout, Cache: cache.NewCache(), } diff --git a/util/sqlgen/group_by.go b/util/sqlgen/group_by.go index 28aa812fce9d25ace9f726c627530568aa83a9a2..fe8ed3f34c7d0b05e343432a8561de96a3ae9656 100644 --- a/util/sqlgen/group_by.go +++ b/util/sqlgen/group_by.go @@ -1,31 +1,50 @@ package sqlgen -type GroupBy Columns +import ( + "fmt" +) + +// GroupBy represents a SQL's "group by" statement. +type GroupBy struct { + Columns Fragment + hash string +} -type groupBy_s struct { +type groupByT struct { GroupColumns string } -func (self GroupBy) Hash() string { - return `GroupBy(` + Columns(self).Hash() + `)` +// Hash returns a unique identifier. +func (g *GroupBy) Hash() string { + if g.hash == "" { + if g.Columns != nil { + g.hash = fmt.Sprintf(`GroupBy(%s)`, g.Columns.Hash()) + } + } + return g.hash +} + +// GroupByColumns creates and returns a GroupBy with the given column. +func GroupByColumns(columns ...Fragment) *GroupBy { + return &GroupBy{Columns: JoinColumns(columns...)} } -func (self GroupBy) Compile(layout *Template) (compiled string) { +// Compile transforms the GroupBy into an equivalent SQL representation. +func (g *GroupBy) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { + if c, ok := layout.Read(g); ok { return c } - if len(self) > 0 { - - data := groupBy_s{ - GroupColumns: Columns(self).Compile(layout), + if g.Columns != nil { + data := groupByT{ + GroupColumns: g.Columns.Compile(layout), } compiled = mustParse(layout.GroupByLayout, data) } - layout.Write(self, compiled) + layout.Write(g, compiled) return } diff --git a/util/sqlgen/group_by_test.go b/util/sqlgen/group_by_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c6c6a6f3e34d7bb4c99ac35e3e020e5f1506a389 --- /dev/null +++ b/util/sqlgen/group_by_test.go @@ -0,0 +1,73 @@ +package sqlgen + +import ( + "testing" +) + +func TestGroupBy(t *testing.T) { + var s, e string + + columns := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + + s = columns.Compile(defaultTemplate) + e = `GROUP BY "id", "customer", "service_id", "role"."name", "role"."id"` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkGroupByColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = GroupByColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkGroupByHash(b *testing.B) { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkGroupByCompile(b *testing.B) { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + for i := 0; i < b.N; i++ { + c.Compile(defaultTemplate) + } +} + +func BenchmarkGroupByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + c.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/interfaces.go b/util/sqlgen/interfaces.go index 234742c40098cf1696bbe100b25faafa6070482a..8d6cb109034d5890bb47b8aeb25395640e7b9398 100644 --- a/util/sqlgen/interfaces.go +++ b/util/sqlgen/interfaces.go @@ -4,8 +4,8 @@ import ( "upper.io/cache" ) -type cc interface { - cache.Cacheable +type Fragment interface { + cache.Hashable compilable } diff --git a/util/sqlgen/main.go b/util/sqlgen/main.go deleted file mode 100644 index 098c113ced577d2df119eded5cd585b42ea513ce..0000000000000000000000000000000000000000 --- a/util/sqlgen/main.go +++ /dev/null @@ -1,44 +0,0 @@ -package sqlgen - -import ( - "bytes" - "text/template" -) - -type Type uint - -const ( - SqlTruncate = iota - SqlDropTable - SqlDropDatabase - SqlSelectCount - SqlInsert - SqlSelect - SqlUpdate - SqlDelete -) - -type ( - Limit int - Offset int - Extra string -) - -var parsedTemplates = make(map[string]*template.Template) - -func mustParse(text string, data interface{}) (compiled string) { - var b bytes.Buffer - var ok bool - - if _, ok = parsedTemplates[text]; ok == false { - parsedTemplates[text] = template.Must(template.New("").Parse(text)) - } - - if err := parsedTemplates[text].Execute(&b, data); err != nil { - panic("There was an error compiling the following template:\n" + text + "\nError was: " + err.Error()) - } - - compiled = b.String() - - return -} diff --git a/util/sqlgen/main_test.go b/util/sqlgen/main_test.go deleted file mode 100644 index c30c851c4103e52b76fef1f8da5d5d4e56e199d2..0000000000000000000000000000000000000000 --- a/util/sqlgen/main_test.go +++ /dev/null @@ -1,662 +0,0 @@ -package sqlgen - -import ( - "testing" -) - -func TestTruncateTable(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlTruncate, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `TRUNCATE TABLE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestDropTable(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlDropTable, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `DROP TABLE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestDropDatabase(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlDropDatabase, - Database: Database{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `DROP DATABASE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectCount(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectCountRelation(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"information_schema.tables"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "information_schema"."tables"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectCountWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelectCount, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"a"}, "=", Value{Raw{"7"}}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFromAlias(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table.name AS foo"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFromRawWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table.name AS foo"}, - Where: Where{ - Raw{"foo.id = bar.foo_id"}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"table.name AS foo"}, - Where: Where{ - Raw{"foo.id = bar.foo_id"}, - Raw{"baz.id = exp.baz_id"}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectStarFromMany(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"first.table AS foo, second.table as BAR, third.table aS baz"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectArtistNameFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{"artist"}, - Columns: Columns{ - {"artist.name"}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "artist"."name" FROM "artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectRawFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Table: Table{`artist`}, - Columns: Columns{ - {`artist.name`}, - {Raw{`CONCAT(artist.name, " ", artist.last_name)`}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFrom(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWithLimitOffset(t *testing.T) { - var s, e string - var stmt Statement - - // LIMIT only. - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Limit: 42, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // OFFSET only. - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Offset: 17, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // LIMIT AND OFFSET. - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Limit: 42, - Offset: 17, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestGroupBy(t *testing.T) { - var s, e string - var stmt Statement - - // Simple GROUP BY - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - GroupBy: GroupBy{ - Column{"foo"}, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - GroupBy: GroupBy{ - Column{"foo"}, - Column{"bar"}, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWithOrderBy(t *testing.T) { - var s, e string - var stmt Statement - - // Simple ORDER BY - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns: SortColumns{ - SortColumn{Column: Column{"foo"}}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY field ASC - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - SortColumn{Column{"foo"}, SqlSortAsc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY field DESC - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - {Column{"foo"}, SqlSortDesc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY many fields - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - {Column{"foo"}, SqlSortDesc}, - {Column{"bar"}, SqlSortAsc}, - {Column{"baz"}, SqlSortDesc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - // ORDER BY function - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - OrderBy: OrderBy{ - SortColumns{ - {Column{Raw{"FOO()"}}, SqlSortDesc}, - {Column{Raw{"BAR()"}}, SqlSortAsc}, - }, - }, - Table: Table{"table_name"}, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlSelect, - Columns: Columns{ - {"foo"}, - {"bar"}, - {"baz"}, - }, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - Limit: 10, - Offset: 23, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestDelete(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlDelete, - Table: Table{"table_name"}, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `DELETE FROM "table_name" WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestUpdate(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlUpdate, - Table: Table{"table_name"}, - ColumnValues: ColumnValues{ - {Column{"foo"}, "=", Value{76}}, - }, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - stmt = Statement{ - Type: SqlUpdate, - Table: Table{"table_name"}, - ColumnValues: ColumnValues{ - {Column{"foo"}, "=", Value{76}}, - {Column{"bar"}, "=", Value{Raw{"88"}}}, - }, - Where: Where{ - ColumnValue{Column{"baz"}, "=", Value{99}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestInsert(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlInsert, - Table: Table{"table_name"}, - Columns: Columns{ - Column{"foo"}, - Column{"bar"}, - Column{"baz"}, - }, - Values: Values{ - Value{"1"}, - Value{2}, - Value{Raw{"3"}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestInsertExtra(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: SqlInsert, - Table: Table{"table_name"}, - Extra: "RETURNING id", - Columns: Columns{ - Column{"foo"}, - Column{"bar"}, - Column{"baz"}, - }, - Values: Values{ - Value{"1"}, - Value{2}, - Value{Raw{"3"}}, - }, - } - - s = trim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING id` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} diff --git a/util/sqlgen/order_by.go b/util/sqlgen/order_by.go index a88fe69a3cb5e7077ece8af0e4f49d94d1b5c86a..d437cde386ba7df348deff2f302b73e088f378fe 100644 --- a/util/sqlgen/order_by.go +++ b/util/sqlgen/order_by.go @@ -1,112 +1,163 @@ package sqlgen import ( + "fmt" "strings" ) +// Order represents the order in which SQL results are sorted. +type Order uint8 + +// Possible values for Order +const ( + DefaultOrder = Order(iota) + Ascendent + Descendent +) + +// SortColumn represents the column-order relation in an ORDER BY clause. type SortColumn struct { - Column - Sort + Column Fragment + Order + hash string } -type sortColumn_s struct { +type sortColumnT struct { Column string - Sort string + Order string } -type SortColumns []SortColumn +// SortColumns represents the columns in an ORDER BY clause. +type SortColumns struct { + Columns []Fragment + hash string +} -func (self SortColumn) Hash() string { - return `SortColumn(` + self.Column.Hash() + `;` + self.Sort.Hash() + `)` +// OrderBy represents an ORDER BY clause. +type OrderBy struct { + SortColumns Fragment + hash string } -func (self SortColumns) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `SortColumns(` + strings.Join(hash, `,`) + `)` +type orderByT struct { + SortColumns string } -func (self SortColumns) Compile(layout *Template) string { - l := len(self) - s := make([]string, 0, l) - for i := 0; i < l; i++ { - s = append(s, self[i].Compile(layout)) +// JoinSortColumns creates and returns an array of column-order relations. +func JoinSortColumns(values ...Fragment) *SortColumns { + return &SortColumns{Columns: values} +} + +// JoinWithOrderBy creates an returns an OrderBy using the given SortColumns. +func JoinWithOrderBy(sc *SortColumns) *OrderBy { + return &OrderBy{SortColumns: sc} +} + +// Hash returns a unique identifier. +func (s *SortColumn) Hash() string { + if s.hash == "" { + s.hash = fmt.Sprintf(`SortColumn{Column:%s, Order:%s}`, s.Column.Hash(), s.Order.Hash()) } - return strings.Join(s, layout.IdentifierSeparator) + return s.hash } -func (self SortColumn) Compile(layout *Template) (compiled string) { +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s *SortColumn) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { + if c, ok := layout.Read(s); ok { return c } - data := sortColumn_s{ - Column: self.Column.Compile(layout), - Sort: self.Sort.Compile(layout), + data := sortColumnT{ + Column: s.Column.Compile(layout), + Order: s.Order.Compile(layout), } compiled = mustParse(layout.SortByColumnLayout, data) - layout.Write(self, compiled) + layout.Write(s, compiled) + return } -type OrderBy struct { - SortColumns +// Hash returns a unique identifier. +func (s *SortColumns) Hash() string { + if s.hash == "" { + h := make([]string, len(s.Columns)) + for i := range s.Columns { + h[i] = s.Columns[i].Hash() + } + s.hash = fmt.Sprintf(`SortColumns(%s)`, strings.Join(h, `, `)) + } + return s.hash } -type orderBy_s struct { - SortColumns string +// Compile transforms the SortColumns into an equivalent SQL representation. +func (s *SortColumns) Compile(layout *Template) (compiled string) { + + if z, ok := layout.Read(s); ok { + return z + } + + z := make([]string, len(s.Columns)) + + for i := range s.Columns { + z[i] = s.Columns[i].Compile(layout) + } + + compiled = strings.Join(z, layout.IdentifierSeparator) + + layout.Write(s, compiled) + + return } -func (self OrderBy) Hash() string { - return `OrderBy(` + self.SortColumns.Hash() + `)` +// Hash returns a unique identifier. +func (s *OrderBy) Hash() string { + if s.hash == "" { + if s.SortColumns != nil { + s.hash = `OrderBy(` + s.SortColumns.Hash() + `)` + } + } + return s.hash } -func (self OrderBy) Compile(layout *Template) (compiled string) { +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s *OrderBy) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(s); ok { + return z } - if len(self.SortColumns) > 0 { - data := orderBy_s{ - SortColumns: self.SortColumns.Compile(layout), + if s.SortColumns != nil { + data := orderByT{ + SortColumns: s.SortColumns.Compile(layout), } compiled = mustParse(layout.OrderByLayout, data) } - layout.Write(self, compiled) + layout.Write(s, compiled) return } -type Sort uint8 - -const ( - SqlSortNone = iota - SqlSortAsc - SqlSortDesc -) - -func (self Sort) Hash() string { - switch self { - case SqlSortAsc: - return `Sort(1)` - case SqlSortDesc: - return `Sort(2)` +// Hash returns a unique identifier. +func (s Order) Hash() string { + switch s { + case Ascendent: + return `Order{ASC}` + case Descendent: + return `Order{DESC}` } - return `Sort(0)` + return `Order{DEFAULT}` } -func (self Sort) Compile(layout *Template) string { - switch self { - case SqlSortAsc: +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s Order) Compile(layout *Template) string { + switch s { + case Ascendent: return layout.AscKeyword - case SqlSortDesc: + case Descendent: return layout.DescKeyword } return "" diff --git a/util/sqlgen/order_by_test.go b/util/sqlgen/order_by_test.go new file mode 100644 index 0000000000000000000000000000000000000000..bbb7ac8421c99406c7d54765a3db5137c2488ffa --- /dev/null +++ b/util/sqlgen/order_by_test.go @@ -0,0 +1,143 @@ +package sqlgen + +import ( + "testing" +) + +func TestOrderBy(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + + s := o.Compile(defaultTemplate) + e := `ORDER BY "foo"` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestOrderByDesc(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + ), + ) + + s := o.Compile(defaultTemplate) + e := `ORDER BY "foo" DESC` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkOrderBy(b *testing.B) { + for i := 0; i < b.N; i++ { + JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + } +} + +func BenchmarkOrderByHash(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + for i := 0; i < b.N; i++ { + o.Hash() + } +} + +func BenchmarkCompileOrderByCompile(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + for i := 0; i < b.N; i++ { + o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompile(b *testing.B) { + o := Descendent + for i := 0; i < b.N; i++ { + o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := Descendent + o.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnHash(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnCompile(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + for i := 0; i < b.N; i++ { + s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := &SortColumn{Column: &Column{Name: "foo"}} + s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsHash(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnsCompile(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + for i := 0; i < b.N; i++ { + s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + s.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/raw.go b/util/sqlgen/raw.go index cda0e66bd25c0805badfdf648b7f5b7ab6481fcf..ca16c26c1f767c1d06b4cdca12ca83f0f94adf1f 100644 --- a/util/sqlgen/raw.go +++ b/util/sqlgen/raw.go @@ -1,17 +1,38 @@ package sqlgen +import ( + "fmt" +) + +var ( + _ = fmt.Stringer(&Raw{}) +) + +// Raw represents a value that is meant to be used in a query without escaping. type Raw struct { - Raw string + Value string // Value should not be modified after assigned. + hash string +} + +// RawValue creates and returns a new raw value. +func RawValue(v string) *Raw { + return &Raw{Value: v} } -func (self Raw) Hash() string { - return `Raw(` + self.Raw + `)` +// Hash returns a unique identifier. +func (r *Raw) Hash() string { + if r.hash == "" { + r.hash = `Raw{Value:"` + r.Value + `"}` + } + return r.hash } -func (self Raw) Compile(*Template) string { - return self.Raw +// Compile returns the raw value. +func (r *Raw) Compile(*Template) string { + return r.Value } -func (self Raw) String() string { - return self.Raw +// String returns the raw value. +func (r *Raw) String() string { + return r.Value } diff --git a/util/sqlgen/raw_test.go b/util/sqlgen/raw_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9ad57ad6c04a37761ddddc1532583b59a579cb63 --- /dev/null +++ b/util/sqlgen/raw_test.go @@ -0,0 +1,72 @@ +package sqlgen + +import ( + "fmt" + "testing" +) + +func TestRawString(t *testing.T) { + var s, e string + + raw := &Raw{Value: "foo"} + + s = raw.Compile(defaultTemplate) + e = `foo` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestRawCompile(t *testing.T) { + var s, e string + + raw := &Raw{Value: "foo"} + + s = raw.Compile(defaultTemplate) + e = `foo` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestRawHash(t *testing.T) { + var s, e string + + raw := &Raw{Value: "foo"} + + s = raw.Hash() + e = fmt.Sprintf(`Raw{Value:"%s"}`, raw) + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkRawCreate(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Raw{Value: "foo"} + } +} + +func BenchmarkRawString(b *testing.B) { + raw := &Raw{Value: "foo"} + for i := 0; i < b.N; i++ { + raw.String() + } +} + +func BenchmarkRawCompile(b *testing.B) { + raw := &Raw{Value: "foo"} + for i := 0; i < b.N; i++ { + raw.Compile(defaultTemplate) + } +} + +func BenchmarkRawHash(b *testing.B) { + raw := &Raw{Value: "foo"} + for i := 0; i < b.N; i++ { + raw.Hash() + } +} diff --git a/util/sqlgen/statement.go b/util/sqlgen/statement.go index b7b4ec472bcc2214f768b63cf854437307c8c002..16b8800d72a2805b6a29986cac5de4f7ceb9c2af 100644 --- a/util/sqlgen/statement.go +++ b/util/sqlgen/statement.go @@ -2,24 +2,29 @@ package sqlgen import ( "strconv" + "strings" + + "upper.io/cache" ) +// Statement represents different kinds of SQL statements. type Statement struct { Type - Table - Database + Table Fragment + Database Fragment Limit Offset - Columns - Values - ColumnValues - OrderBy - GroupBy + Columns Fragment + Values Fragment + ColumnValues Fragment + OrderBy Fragment + GroupBy Fragment Extra - Where + Where Fragment + hash string } -type statement_s struct { +type statementT struct { Table string Database string Limit @@ -33,64 +38,86 @@ type statement_s struct { Where string } -func (self Statement) Hash() string { - hash := `Statement(` + - strconv.Itoa(int(self.Type)) + `;` + - self.Table.Hash() + `;` + - self.Database.Hash() + `;` + - strconv.Itoa(int(self.Limit)) + `;` + - strconv.Itoa(int(self.Offset)) + `;` + - self.Columns.Hash() + `;` + - self.Values.Hash() + `;` + - self.ColumnValues.Hash() + `;` + - self.OrderBy.Hash() + `;` + - self.GroupBy.Hash() + `;` + - string(self.Extra) + `;` + - self.Where.Hash() + - `)` - return hash +func (layout *Template) doCompile(c Fragment) string { + if c != nil { + return c.Compile(layout) + } + return "" +} + +func (s Statement) getHash(h cache.Hashable) string { + if h != nil { + return h.Hash() + } + return "" +} + +// Hash returns a unique identifier. +func (s *Statement) Hash() string { + if s.hash == "" { + parts := strings.Join([]string{ + strconv.Itoa(int(s.Type)), + s.getHash(s.Table), + s.getHash(s.Database), + strconv.Itoa(int(s.Limit)), + strconv.Itoa(int(s.Offset)), + s.getHash(s.Columns), + s.getHash(s.Values), + s.getHash(s.ColumnValues), + s.getHash(s.OrderBy), + s.getHash(s.GroupBy), + string(s.Extra), + s.getHash(s.Where), + }, ";") + + s.hash = `Statement(` + parts + `)` + } + return s.hash } -func (self *Statement) Compile(layout *Template) (compiled string) { +// Compile transforms the Statement into an equivalent SQL query. +func (s *Statement) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(s); ok { + return z } - data := statement_s{ - Table: self.Table.Compile(layout), - Database: self.Database.Compile(layout), - Limit: self.Limit, - Offset: self.Offset, - Columns: self.Columns.Compile(layout), - Values: self.Values.Compile(layout), - ColumnValues: self.ColumnValues.Compile(layout), - OrderBy: self.OrderBy.Compile(layout), - GroupBy: self.GroupBy.Compile(layout), - Extra: string(self.Extra), - Where: self.Where.Compile(layout), + data := statementT{ + Table: layout.doCompile(s.Table), + Database: layout.doCompile(s.Database), + Limit: s.Limit, + Offset: s.Offset, + Columns: layout.doCompile(s.Columns), + Values: layout.doCompile(s.Values), + ColumnValues: layout.doCompile(s.ColumnValues), + OrderBy: layout.doCompile(s.OrderBy), + GroupBy: layout.doCompile(s.GroupBy), + Extra: string(s.Extra), + Where: layout.doCompile(s.Where), } - switch self.Type { - case SqlTruncate: + switch s.Type { + case Truncate: compiled = mustParse(layout.TruncateLayout, data) - case SqlDropTable: + case DropTable: compiled = mustParse(layout.DropTableLayout, data) - case SqlDropDatabase: + case DropDatabase: compiled = mustParse(layout.DropDatabaseLayout, data) - case SqlSelectCount: - compiled = mustParse(layout.SelectCountLayout, data) - case SqlSelect: + case Count: + compiled = mustParse(layout.CountLayout, data) + case Select: compiled = mustParse(layout.SelectLayout, data) - case SqlDelete: + case Delete: compiled = mustParse(layout.DeleteLayout, data) - case SqlUpdate: + case Update: compiled = mustParse(layout.UpdateLayout, data) - case SqlInsert: + case Insert: compiled = mustParse(layout.InsertLayout, data) + default: + panic("Unknown template type.") } - layout.Write(self, compiled) + layout.Write(s, compiled) return compiled } diff --git a/util/sqlgen/statement_test.go b/util/sqlgen/statement_test.go new file mode 100644 index 0000000000000000000000000000000000000000..34859a2d5ee18a0a56e647850aae60d027ed77d1 --- /dev/null +++ b/util/sqlgen/statement_test.go @@ -0,0 +1,757 @@ +package sqlgen + +import ( + "regexp" + "strings" + "testing" +) + +var ( + reInvisible = regexp.MustCompile(`[\t\n\r]`) + reSpace = regexp.MustCompile(`\s+`) +) + +func trim(a string) string { + a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") + a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") + return a +} + +func TestTruncateTable(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Truncate, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `TRUNCATE TABLE "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDropTable(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: DropTable, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `DROP TABLE "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDropDatabase(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: DropDatabase, + Database: &Database{Name: "table_name"}, + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `DROP DATABASE "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestCount(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Count, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT COUNT(1) AS _t FROM "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestCountRelation(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Count, + Table: TableWithName("information_schema.tables"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT COUNT(1) AS _t FROM "information_schema"."tables"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestCountWhere(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(RawValue("7"))}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFromAlias(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table"."name" AS "foo"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFromRawWhere(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + stmt = Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + &Raw{Value: "baz.id = exp.baz_id"}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectStarFromMany(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectArtistNameFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName("artist"), + Columns: JoinColumns( + &Column{Name: "artist.name"}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "artist"."name" FROM "artist"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectRawFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Table: TableWithName(`artist`), + Columns: JoinColumns( + &Column{Name: `artist.name`}, + &Column{Name: Raw{Value: `CONCAT(artist.name, " ", artist.last_name)`}}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWithLimitOffset(t *testing.T) { + var s, e string + var stmt Statement + + // LIMIT only. + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // OFFSET only. + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Offset: 17, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // LIMIT AND OFFSET. + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Offset: 17, + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestStatementGroupBy(t *testing.T) { + var s, e string + var stmt Statement + + // Simple GROUP BY + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWithOrderBy(t *testing.T) { + var s, e string + var stmt Statement + + // Simple ORDER BY + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY field ASC + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY field DESC + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY many fields + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + &SortColumn{Column: &Column{Name: "bar"}, Order: Ascendent}, + &SortColumn{Column: &Column{Name: "baz"}, Order: Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + // ORDER BY function + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: Raw{Value: "FOO()"}}, Order: Descendent}, + &SortColumn{Column: &Column{Name: Raw{Value: "BAR()"}}, Order: Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWhere(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + Limit: 10, + Offset: 23, + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestDelete(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Delete, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `DELETE FROM "table_name" WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestUpdate(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } + + stmt = Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + &ColumnValue{Column: &Column{Name: "bar"}, Operator: "=", Value: NewValue(Raw{Value: "88"})}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestInsert(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func TestInsertExtra(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Extra: "RETURNING id", + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + + s = trim(stmt.Compile(defaultTemplate)) + e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING id` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + +func BenchmarkStatementSimpleQuery(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + ), + } + + for i := 0; i < b.N; i++ { + _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementSimpleQueryHash(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + ), + } + + for i := 0; i < b.N; i++ { + _ = stmt.Hash() + } +} + +func BenchmarkStatementSimpleQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + ), + } + _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQuery(b *testing.B) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + + for i := 0; i < b.N; i++ { + _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValues( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: Raw{Value: "3"}}, + ), + } + _ = stmt.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/table.go b/util/sqlgen/table.go index 257f65ba8d500f995d0fbedf1d11ebe7ba5b0e68..43dcda62cf50ba9d85020915916aa211954c186f 100644 --- a/util/sqlgen/table.go +++ b/util/sqlgen/table.go @@ -13,6 +13,7 @@ type tableT struct { // Table struct represents a SQL table. type Table struct { Name interface{} + hash string } func quotedTableName(layout *Template, input string) string { @@ -33,7 +34,7 @@ func quotedTableName(layout *Template, input string) string { for i := range nameChunks { // nameChunks[i] = strings.TrimSpace(nameChunks[i]) nameChunks[i] = trimString(nameChunks[i]) - nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{nameChunks[i]}) + nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) } name = strings.Join(nameChunks, layout.ColumnSeparator) @@ -43,28 +44,44 @@ func quotedTableName(layout *Template, input string) string { if len(chunks) > 1 { // alias = strings.TrimSpace(chunks[1]) alias = trimString(chunks[1]) - alias = mustParse(layout.IdentifierQuote, Raw{alias}) + alias = mustParse(layout.IdentifierQuote, Raw{Value: alias}) } return mustParse(layout.TableAliasLayout, tableT{name, alias}) } +// TableWithName creates an returns a Table with the given name. +func TableWithName(name string) *Table { + return &Table{Name: name} +} + // Hash returns a string hash of the table value. -func (t Table) Hash() string { - switch t := t.Name.(type) { - case cc: - return `Table(` + t.Hash() + `)` - case string: - return `Table(` + t + `)` +func (t *Table) Hash() string { + if t.hash == "" { + var s string + + switch v := t.Name.(type) { + case Fragment: + s = v.Hash() + case fmt.Stringer: + s = v.String() + case string: + s = v + default: + s = fmt.Sprintf("%v", t.Name) + } + + t.hash = fmt.Sprintf(`Table{Name:%q}`, s) } - return fmt.Sprintf(`Table(%v)`, t.Name) + + return t.hash } // Compile transforms a table struct into a SQL chunk. -func (t Table) Compile(layout *Template) (compiled string) { +func (t *Table) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(t); ok { - return c + if z, ok := layout.Read(t); ok { + return z } switch value := t.Name.(type) { diff --git a/util/sqlgen/table_test.go b/util/sqlgen/table_test.go index dbae8be34841b8255a383c519fbd32d8b0444f43..df9528de00911d4e5d1eb11ab99526db93c55d11 100644 --- a/util/sqlgen/table_test.go +++ b/util/sqlgen/table_test.go @@ -6,9 +6,8 @@ import ( func TestTableSimple(t *testing.T) { var s, e string - var table Table - table = Table{"artist"} + table := TableWithName("artist") s = trim(table.Compile(defaultTemplate)) e = `"artist"` @@ -20,9 +19,8 @@ func TestTableSimple(t *testing.T) { func TestTableCompound(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo"} + table := TableWithName("artist.foo") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo"` @@ -34,9 +32,8 @@ func TestTableCompound(t *testing.T) { func TestTableCompoundAlias(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo AS baz"} + table := TableWithName("artist.foo AS baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo" AS "baz"` @@ -48,9 +45,8 @@ func TestTableCompoundAlias(t *testing.T) { func TestTableImplicitAlias(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo baz"} + table := TableWithName("artist.foo baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo" AS "baz"` @@ -62,9 +58,8 @@ func TestTableImplicitAlias(t *testing.T) { func TestTableMultiple(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo, artist.bar, artist.baz"} + table := TableWithName("artist.foo, artist.bar, artist.baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo", "artist"."bar", "artist"."baz"` @@ -76,9 +71,8 @@ func TestTableMultiple(t *testing.T) { func TestTableMultipleAlias(t *testing.T) { var s, e string - var table Table - table = Table{"artist.foo AS foo, artist.bar as bar, artist.baz As baz"} + table := TableWithName("artist.foo AS foo, artist.bar as bar, artist.baz As baz") s = trim(table.Compile(defaultTemplate)) e = `"artist"."foo" AS "foo", "artist"."bar" AS "bar", "artist"."baz" AS "baz"` @@ -90,9 +84,8 @@ func TestTableMultipleAlias(t *testing.T) { func TestTableMinimal(t *testing.T) { var s, e string - var table Table - table = Table{"a"} + table := TableWithName("a") s = trim(table.Compile(defaultTemplate)) e = `"a"` @@ -104,9 +97,8 @@ func TestTableMinimal(t *testing.T) { func TestTableEmpty(t *testing.T) { var s, e string - var table Table - table = Table{""} + table := TableWithName("") s = trim(table.Compile(defaultTemplate)) e = `` @@ -115,3 +107,30 @@ func TestTableEmpty(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkTableWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = TableWithName("foo") + } +} + +func BenchmarkTableHash(b *testing.B) { + t := TableWithName("name") + for i := 0; i < b.N; i++ { + t.Hash() + } +} + +func BenchmarkTableCompile(b *testing.B) { + t := TableWithName("name") + for i := 0; i < b.N; i++ { + t.Compile(defaultTemplate) + } +} + +func BenchmarkTableCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + t := TableWithName("name") + t.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/template.go b/util/sqlgen/template.go index e3ff175cb0619793fe32f907581eb7de2e687104..bea487f39f26fb4061b9aeef3c81bd2973f7fba5 100644 --- a/util/sqlgen/template.go +++ b/util/sqlgen/template.go @@ -1,9 +1,41 @@ package sqlgen import ( + "bytes" + "text/template" + "upper.io/cache" ) +// Type is the type of SQL query the statement represents. +type Type uint + +// Values for Type. +const ( + Truncate = Type(iota) + DropTable + DropDatabase + Count + Insert + Select + Update + Delete +) + +type ( + // Limit represents the SQL limit in a query. + Limit int + // Offset represents the SQL offset in a query. + Offset int + // Extra represents any custom SQL that is to be appended to the query. + Extra string +) + +var ( + parsedTemplates = make(map[string]*template.Template) +) + +// Template is an SQL template. type Template struct { ColumnSeparator string IdentifierSeparator string @@ -16,6 +48,7 @@ type Template struct { DescKeyword string AscKeyword string DefaultOperator string + AssignmentOperator string ClauseGroup string ClauseOperator string ColumnValue string @@ -31,7 +64,22 @@ type Template struct { TruncateLayout string DropDatabaseLayout string DropTableLayout string - SelectCountLayout string + CountLayout string GroupByLayout string *cache.Cache } + +func mustParse(text string, data interface{}) string { + var b bytes.Buffer + var ok bool + + if _, ok = parsedTemplates[text]; !ok { + parsedTemplates[text] = template.Must(template.New("").Parse(text)) + } + + if err := parsedTemplates[text].Execute(&b, data); err != nil { + panic("There was an error compiling the following template:\n" + text + "\nError was: " + err.Error()) + } + + return b.String() +} diff --git a/util/sqlgen/util_test.go b/util/sqlgen/util_test.go deleted file mode 100644 index 0551cf34fea3e9d0e088a79b8d313daff426cf1e..0000000000000000000000000000000000000000 --- a/util/sqlgen/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package sqlgen - -import ( - "regexp" - "strings" -) - -var ( - reInvisible = regexp.MustCompile(`[\t\n\r]`) - reSpace = regexp.MustCompile(`\s+`) -) - -func trim(a string) string { - a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") - a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") - return a -} diff --git a/util/sqlgen/utilities.go b/util/sqlgen/utilities.go index 40bfe1e04ab17d458df54f2b5eac3f68f69e1f29..305ab209026dfac06483cfb5de8a78a0377eb97a 100644 --- a/util/sqlgen/utilities.go +++ b/util/sqlgen/utilities.go @@ -10,50 +10,61 @@ const ( stageClose ) -func isSpace(in byte) bool { +// isBlankSymbol returns true if the given byte is either space, tab, carriage +// return or newline. +func isBlankSymbol(in byte) bool { return in == ' ' || in == '\t' || in == '\r' || in == '\n' } -func trimString(in string) string { +// trimString returns a slice of s with a leading and trailing blank symbols +// (as defined by isBlankSymbol) removed. +func trimString(s string) string { - start, end := 0, len(in)-1 + // This conversion is rather slow. + // return string(trimBytes([]byte(s))) - // Where do we start cutting? - for ; start <= end; start++ { - if isSpace(in[start]) == false { - break - } + start, end := 0, len(s)-1 + + if end < start { + return "" } - // Where do we end cutting? - for ; end >= start; end-- { - if isSpace(in[end]) == false { - break + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return "" } } - return in[start : end+1] + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] } -func trimByte(in []byte) []byte { +// trimBytes returns a slice of s with a leading and trailing blank symbols (as +// defined by isBlankSymbol) removed. +func trimBytes(s []byte) []byte { - start, end := 0, len(in)-1 + start, end := 0, len(s)-1 - // Where do we start cutting? - for ; start <= end; start++ { - if isSpace(in[start]) == false { - break - } + if end < start { + return []byte{} } - // Where do we end cutting? - for ; end >= start; end-- { - if isSpace(in[end]) == false { - break + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return []byte{} } } - return in[start : end+1] + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] } /* @@ -95,15 +106,12 @@ func separateByComma(in string) (out []string) { // Separates by spaces, ignoring spaces too. func separateBySpace(in string) (out []string) { - l := len(in) - - if l == 0 { + if len(in) == 0 { return []string{""} } - out = make([]string, 0, l) - pre := strings.Split(in, " ") + out = make([]string, 0, len(pre)) for i := range pre { pre[i] = trimString(pre[i]) @@ -119,7 +127,7 @@ func separateByAS(in string) (out []string) { out = []string{} if len(in) < 6 { - // Min expression: "a AS b" + // The minimum expression with the AS keyword is "x AS y", 6 chars. return []string{in} } @@ -129,7 +137,7 @@ func separateByAS(in string) (out []string) { var end int for end = start; end <= lim; end++ { - if end > 3 && isSpace(in[end]) && isSpace(in[end-3]) { + if end > 3 && isBlankSymbol(in[end]) && isBlankSymbol(in[end-3]) { if (in[end-1] == 's' || in[end-1] == 'S') && (in[end-2] == 'a' || in[end-2] == 'A') { break } diff --git a/util/sqlgen/utilities_test.go b/util/sqlgen/utilities_test.go index 87cff171c458d7b80a817a5c58d1d9398dfba091..4d5ec550618b0759bcab511f92cbfcc06c67fcdc 100644 --- a/util/sqlgen/utilities_test.go +++ b/util/sqlgen/utilities_test.go @@ -5,45 +5,63 @@ import ( "regexp" "strings" "testing" + "unicode" ) -func TestUtilIsSpace(t *testing.T) { - if isSpace(' ') == false { +const ( + blankSymbol = ' ' + stringWithCommas = "Hello,,World!,Enjoy" + stringWithSpaces = " Hello World! Enjoy" + stringWithASKeyword = "table.Name AS myTableAlias" +) + +var ( + bytesWithLeadingBlanks = []byte(" Hello world! ") + stringWithLeadingBlanks = string(bytesWithLeadingBlanks) +) + +func TestUtilIsBlankSymbol(t *testing.T) { + if isBlankSymbol(' ') == false { t.Fail() } - if isSpace('\n') == false { + if isBlankSymbol('\n') == false { t.Fail() } - if isSpace('\t') == false { + if isBlankSymbol('\t') == false { t.Fail() } - if isSpace('\r') == false { + if isBlankSymbol('\r') == false { t.Fail() } - if isSpace('x') == true { + if isBlankSymbol('x') == true { t.Fail() } } -func TestUtilTrimByte(t *testing.T) { +func TestUtilTrimBytes(t *testing.T) { var trimmed []byte - trimmed = trimByte([]byte(" \t\nHello World! \n")) + trimmed = trimBytes([]byte(" \t\nHello World! \n")) if string(trimmed) != "Hello World!" { t.Fatalf("Got: %s\n", string(trimmed)) } - trimmed = trimByte([]byte("Nope")) + trimmed = trimBytes([]byte("Nope")) if string(trimmed) != "Nope" { t.Fatalf("Got: %s\n", string(trimmed)) } - trimmed = trimByte([]byte("")) + trimmed = trimBytes([]byte("")) if string(trimmed) != "" { t.Fatalf("Got: %s\n", string(trimmed)) } - trimmed = trimByte(nil) + trimmed = trimBytes([]byte(" ")) + if string(trimmed) != "" { + t.Fatalf("Got: %s\n", string(trimmed)) + } + + trimmed = trimBytes(nil) if string(trimmed) != "" { t.Fatalf("Got: %s\n", string(trimmed)) } @@ -191,81 +209,76 @@ func TestUtilSeparateByAS(t *testing.T) { } } -func BenchmarkUtilIsSpace(b *testing.B) { +func BenchmarkUtilIsBlankSymbol(b *testing.B) { for i := 0; i < b.N; i++ { - _ = isSpace(' ') + _ = isBlankSymbol(blankSymbol) } } -func BenchmarkUtilTrimByte(b *testing.B) { - s := []byte(" Hello world! ") +func BenchmarkUtilStdlibIsBlankSymbol(b *testing.B) { for i := 0; i < b.N; i++ { - _ = trimByte(s) + _ = unicode.IsSpace(blankSymbol) } } -func BenchmarkUtilTrimString(b *testing.B) { - s := " Hello world! " +func BenchmarkUtilTrimBytes(b *testing.B) { for i := 0; i < b.N; i++ { - _ = trimString(s) + _ = trimBytes(bytesWithLeadingBlanks) + } +} +func BenchmarkUtilStdlibBytesTrimSpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = bytes.TrimSpace(bytesWithLeadingBlanks) } } -func BenchmarkUtilStdBytesTrimSpace(b *testing.B) { - s := []byte(" Hello world! ") +func BenchmarkUtilTrimString(b *testing.B) { for i := 0; i < b.N; i++ { - _ = bytes.TrimSpace(s) + _ = trimString(stringWithLeadingBlanks) } } -func BenchmarkUtilStdStringsTrimSpace(b *testing.B) { - s := " Hello world! " +func BenchmarkUtilStdlibStringsTrimSpace(b *testing.B) { for i := 0; i < b.N; i++ { - _ = strings.TrimSpace(s) + _ = strings.TrimSpace(stringWithLeadingBlanks) } } func BenchmarkUtilSeparateByComma(b *testing.B) { - s := "Hello,,World!,Enjoy" for i := 0; i < b.N; i++ { - _ = separateByComma(s) + _ = separateByComma(stringWithCommas) } } -func BenchmarkUtilSeparateBySpace(b *testing.B) { - s := " Hello World! Enjoy" +func BenchmarkUtilRegExpSeparateByComma(b *testing.B) { + sep := regexp.MustCompile(`\s*?,\s*?`) for i := 0; i < b.N; i++ { - _ = separateBySpace(s) + _ = sep.Split(stringWithCommas, -1) } } -func BenchmarkUtilSeparateByAS(b *testing.B) { - s := "table.Name AS myTableAlias" +func BenchmarkUtilSeparateBySpace(b *testing.B) { for i := 0; i < b.N; i++ { - _ = separateByAS(s) + _ = separateBySpace(stringWithSpaces) } } -func BenchmarkUtilSeparateByCommaRegExp(b *testing.B) { - sep := regexp.MustCompile(`\s*?,\s*?`) - s := "Hello,,World!,Enjoy" +func BenchmarkUtilRegExpSeparateBySpace(b *testing.B) { + sep := regexp.MustCompile(`\s+`) for i := 0; i < b.N; i++ { - _ = sep.Split(s, -1) + _ = sep.Split(stringWithSpaces, -1) } } -func BenchmarkUtilSeparateBySpaceRegExp(b *testing.B) { - sep := regexp.MustCompile(`\s+`) - s := " Hello World! Enjoy" +func BenchmarkUtilSeparateByAS(b *testing.B) { for i := 0; i < b.N; i++ { - _ = sep.Split(s, -1) + _ = separateByAS(stringWithASKeyword) } } -func BenchmarkUtilSeparateByASRegExp(b *testing.B) { +func BenchmarkUtilRegExpSeparateByAS(b *testing.B) { sep := regexp.MustCompile(`(?i:\s+AS\s+)`) - s := "table.Name AS myTableAlias" for i := 0; i < b.N; i++ { - _ = sep.Split(s, -1) + _ = sep.Split(stringWithASKeyword, -1) } } diff --git a/util/sqlgen/value.go b/util/sqlgen/value.go index 5807578d0c10a8884098c0a0b93d8595272a370c..eb224e237078f3de0a42dff6a929746ee2251957 100644 --- a/util/sqlgen/value.go +++ b/util/sqlgen/value.go @@ -1,86 +1,119 @@ package sqlgen import ( - "database/sql/driver" + //"database/sql/driver" "fmt" - "log" + //"log" "strings" ) -type Values []Value +// Values represents an array of Value. +type Values struct { + Values []Fragment + hash string +} +// Value represents an escaped SQL value. type Value struct { - Val interface{} + V interface{} + hash string +} + +// NewValue creates and returns a Value. +func NewValue(v interface{}) *Value { + return &Value{V: v} +} + +// JoinValues creates and returns an array of values. +func JoinValues(v ...Fragment) *Values { + return &Values{Values: v} } -func (self Value) Hash() string { - switch t := self.Val.(type) { - case cc: - return `Value(` + t.Hash() + `)` - case string: - return `Value(` + t + `)` +// Hash returns a unique identifier. +func (v *Value) Hash() string { + if v.hash == "" { + switch t := v.V.(type) { + case Fragment: + v.hash = `Value(` + t.Hash() + `)` + case string: + v.hash = `Value(` + t + `)` + default: + v.hash = fmt.Sprintf(`Value(%v)`, v.V) + } } - return fmt.Sprintf(`Value(%v)`, self.Val) + return v.hash } -func (self Value) Compile(layout *Template) (compiled string) { +// Compile transforms the Value into an equivalent SQL representation. +func (v *Value) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c + if z, ok := layout.Read(v); ok { + return z } - if raw, ok := self.Val.(Raw); ok { - compiled = raw.Raw + if raw, ok := v.V.(Raw); ok { + compiled = raw.Compile(layout) + } else if raw, ok := v.V.(Fragment); ok { + compiled = raw.Compile(layout) } else { - compiled = mustParse(layout.ValueQuote, Raw{fmt.Sprintf(`%v`, self.Val)}) + compiled = mustParse(layout.ValueQuote, RawValue(fmt.Sprintf(`%v`, v.V))) } - layout.Write(self, compiled) + layout.Write(v, compiled) return } -func (self Value) Scan(src interface{}) error { - log.Println("Scan(", src, ") on", self.Val) +/* +func (v *Value) Scan(src interface{}) error { + log.Println("Scan(", src, ") on", v.V) return nil } -func (self Value) Value() (driver.Value, error) { - log.Println("Value() on", self.Val) - return self.Val, nil +func (v *Value) Value() (driver.Value, error) { + log.Println("Value() on", v.V) + return v.V, nil } - -func (self Values) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +*/ + +// Hash returns a unique identifier. +func (vs *Values) Hash() string { + if vs.hash == "" { + hash := make([]string, len(vs.Values)) + for i := range vs.Values { + hash[i] = vs.Values[i].Hash() + } + vs.hash = `Values(` + strings.Join(hash, `,`) + `)` } - return `Values(` + strings.Join(hash, `,`) + `)` + return vs.hash } -func (self Values) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { +// Compile transforms the Values into an equivalent SQL representation. +func (vs *Values) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(vs); ok { return c } - l := len(self) + l := len(vs.Values) if l > 0 { chunks := make([]string, 0, l) for i := 0; i < l; i++ { - chunks = append(chunks, self[i].Compile(layout)) + chunks = append(chunks, vs.Values[i].Compile(layout)) } compiled = strings.Join(chunks, layout.ValueSeparator) } - layout.Write(self, compiled) + layout.Write(vs, compiled) return } -func (self Values) Scan(src interface{}) error { +/* +func (vs Values) Scan(src interface{}) error { log.Println("Values.Scan(", src, ")") return nil } -func (self Values) Value() (driver.Value, error) { +func (vs Values) Value() (driver.Value, error) { log.Println("Values.Value()") - return self, nil + return vs, nil } +*/ diff --git a/util/sqlgen/value_test.go b/util/sqlgen/value_test.go index 1c5b5e30b4ce1406b2f3e49eff278c2abb9addd3..8c621700d7a2327fa65cb687468d52e3258a0218 100644 --- a/util/sqlgen/value_test.go +++ b/util/sqlgen/value_test.go @@ -6,9 +6,9 @@ import ( func TestValue(t *testing.T) { var s, e string - var val Value + var val *Value - val = Value{1} + val = NewValue(1) s = val.Compile(defaultTemplate) e = `'1'` @@ -17,7 +17,7 @@ func TestValue(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } - val = Value{Raw{"NOW()"}} + val = NewValue(&Raw{Value: "NOW()"}) s = val.Compile(defaultTemplate) e = `NOW()` @@ -29,13 +29,12 @@ func TestValue(t *testing.T) { func TestValues(t *testing.T) { var s, e string - var val Values - val = Values{ - Value{Raw{"1"}}, - Value{Raw{"2"}}, - Value{"3"}, - } + val := JoinValues( + &Value{V: &Raw{Value: "1"}}, + &Value{V: &Raw{Value: "2"}}, + &Value{V: "3"}, + ) s = val.Compile(defaultTemplate) e = `1, 2, '3'` @@ -44,3 +43,57 @@ func TestValues(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = NewValue("a") + } +} + +func BenchmarkValueHash(b *testing.B) { + v := NewValue("a") + for i := 0; i < b.N; i++ { + _ = v.Hash() + } +} + +func BenchmarkValueCompile(b *testing.B) { + v := NewValue("a") + for i := 0; i < b.N; i++ { + _ = v.Compile(defaultTemplate) + } +} + +func BenchmarkValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + v := NewValue("a") + _ = v.Compile(defaultTemplate) + } +} + +func BenchmarkValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinValues(NewValue("a"), NewValue("b")) + } +} + +func BenchmarkValuesHash(b *testing.B) { + vs := JoinValues(NewValue("a"), NewValue("b")) + for i := 0; i < b.N; i++ { + _ = vs.Hash() + } +} + +func BenchmarkValuesCompile(b *testing.B) { + vs := JoinValues(NewValue("a"), NewValue("b")) + for i := 0; i < b.N; i++ { + _ = vs.Compile(defaultTemplate) + } +} + +func BenchmarkValuesCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + vs := JoinValues(NewValue("a"), NewValue("b")) + _ = vs.Compile(defaultTemplate) + } +} diff --git a/util/sqlgen/where.go b/util/sqlgen/where.go index bcc9b677649258ac887fea097f5cab3b8f3eb3d3..64ea9f106424231e3c38dd73c405cbb26a0feb9e 100644 --- a/util/sqlgen/where.go +++ b/util/sqlgen/where.go @@ -1,84 +1,110 @@ package sqlgen import ( + "fmt" "strings" ) -type ( - Or []cc - And []cc - Where []cc -) +// Or represents an SQL OR operator. +type Or Where + +// And represents an SQL AND operator. +type And Where + +// Where represents an SQL WHERE clause. +type Where struct { + Conditions []Fragment + hash string +} type conds struct { Conds string } -func (self Or) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `Or(` + strings.Join(hash, `,`) + `)` +// WhereConditions creates and retuens a new Where. +func WhereConditions(conditions ...Fragment) *Where { + return &Where{Conditions: conditions} } -func (self Or) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c - } +// JoinWithOr creates and returns a new Or. +func JoinWithOr(conditions ...Fragment) *Or { + return &Or{Conditions: conditions} +} - compiled = groupCondition(layout, self, mustParse(layout.ClauseOperator, layout.OrKeyword)) +// JoinWithAnd creates and returns a new And. +func JoinWithAnd(conditions ...Fragment) *And { + return &And{Conditions: conditions} +} - layout.Write(self, compiled) +// Hash returns a unique identifier. +func (w *Where) Hash() string { + if w.hash == "" { + hash := make([]string, len(w.Conditions)) + for i := range w.Conditions { + hash[i] = w.Conditions[i].Hash() + } + w.hash = fmt.Sprintf(`Where{%s}`, strings.Join(hash, `, `)) + } + return w.hash +} - return +// Hash returns a unique identifier. +func (o *Or) Hash() string { + w := Where(*o) + return `Or(` + w.Hash() + `)` } -func (self And) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) - } - return `And(` + strings.Join(hash, `,`) + `)` +// Hash returns a unique identifier. +func (a *And) Hash() string { + w := Where(*a) + return `Or(` + w.Hash() + `)` } -func (self And) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { - return c +// Compile transforms the Or into an equivalent SQL representation. +func (o *Or) Compile(layout *Template) (compiled string) { + + if z, ok := layout.Read(o); ok { + return z } - compiled = groupCondition(layout, self, mustParse(layout.ClauseOperator, layout.AndKeyword)) + compiled = groupCondition(layout, o.Conditions, mustParse(layout.ClauseOperator, layout.OrKeyword)) - layout.Write(self, compiled) + layout.Write(o, compiled) return } -func (self Where) Hash() string { - hash := make([]string, 0, len(self)) - for i := range self { - hash = append(hash, self[i].Hash()) +// Compile transforms the And into an equivalent SQL representation. +func (a *And) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(a); ok { + return c } - return `Where(` + strings.Join(hash, `,`) + `)` + + compiled = groupCondition(layout, a.Conditions, mustParse(layout.ClauseOperator, layout.AndKeyword)) + + layout.Write(a, compiled) + + return } -func (self Where) Compile(layout *Template) (compiled string) { - if c, ok := layout.Read(self); ok { +// Compile transforms the Where into an equivalent SQL representation. +func (w *Where) Compile(layout *Template) (compiled string) { + if c, ok := layout.Read(w); ok { return c } - grouped := groupCondition(layout, self, mustParse(layout.ClauseOperator, layout.AndKeyword)) + grouped := groupCondition(layout, w.Conditions, mustParse(layout.ClauseOperator, layout.AndKeyword)) if grouped != "" { compiled = mustParse(layout.WhereLayout, conds{grouped}) } - layout.Write(self, compiled) + layout.Write(w, compiled) return } -func groupCondition(layout *Template, terms []cc, joinKeyword string) string { +func groupCondition(layout *Template, terms []Fragment, joinKeyword string) string { l := len(terms) chunks := make([]string, 0, l) diff --git a/util/sqlgen/where_test.go b/util/sqlgen/where_test.go index 111f79b41d44ed2c8f0a837d91b4c2f31f44b342..c4b2a182ac7e1cadcfdc216f029f529c4864290f 100644 --- a/util/sqlgen/where_test.go +++ b/util/sqlgen/where_test.go @@ -6,13 +6,12 @@ import ( func TestWhereAnd(t *testing.T) { var s, e string - var and And - and = And{ - ColumnValue{Column{"id"}, ">", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "<", Value{Raw{"99"}}}, - ColumnValue{Column{"name"}, "=", Value{"John"}}, - } + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + ) s = and.Compile(defaultTemplate) e = `("id" > 8 AND "id" < 99 AND "name" = 'John')` @@ -24,12 +23,11 @@ func TestWhereAnd(t *testing.T) { func TestWhereOr(t *testing.T) { var s, e string - var or Or - or = Or{ - ColumnValue{Column{"id"}, "=", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "=", Value{Raw{"99"}}}, - } + or := JoinWithOr( + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "99"})}, + ) s = or.Compile(defaultTemplate) e = `("id" = 8 OR "id" = 99)` @@ -41,17 +39,16 @@ func TestWhereOr(t *testing.T) { func TestWhereAndOr(t *testing.T) { var s, e string - var and And - - and = And{ - ColumnValue{Column{"id"}, ">", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "<", Value{Raw{"99"}}}, - ColumnValue{Column{"name"}, "=", Value{"John"}}, - Or{ - ColumnValue{Column{"last_name"}, "=", Value{"Smith"}}, - ColumnValue{Column{"last_name"}, "=", Value{"Reyes"}}, - }, - } + + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + ) s = and.Compile(defaultTemplate) e = `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))` @@ -63,24 +60,23 @@ func TestWhereAndOr(t *testing.T) { func TestWhereAndRawOrAnd(t *testing.T) { var s, e string - var where Where - - where = Where{ - And{ - ColumnValue{Column{"id"}, ">", Value{Raw{"8"}}}, - ColumnValue{Column{"id"}, "<", Value{Raw{"99"}}}, - }, - ColumnValue{Column{"name"}, "=", Value{"John"}}, - Raw{"city_id = 728"}, - Or{ - ColumnValue{Column{"last_name"}, "=", Value{"Smith"}}, - ColumnValue{Column{"last_name"}, "=", Value{"Reyes"}}, - }, - And{ - ColumnValue{Column{"age"}, ">", Value{Raw{"18"}}}, - ColumnValue{Column{"age"}, "<", Value{Raw{"41"}}}, - }, - } + + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) s = trim(where.Compile(defaultTemplate)) e = `WHERE (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))` @@ -89,3 +85,29 @@ func TestWhereAndRawOrAnd(t *testing.T) { t.Fatalf("Got: %s, Expecting: %s", s, e) } } + +func BenchmarkWhere(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + } +} + +func BenchmarkCompileWhere(b *testing.B) { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + for i := 0; i < b.N; i++ { + w.Compile(defaultTemplate) + } +} + +func BenchmarkCompileWhereNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + w.Compile(defaultTemplate) + } +} diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go new file mode 100644 index 0000000000000000000000000000000000000000..d7bc224c8e75bf632986ed8ea87513644dc238e0 --- /dev/null +++ b/util/sqlutil/convert.go @@ -0,0 +1,201 @@ +package sqlutil + +import ( + "fmt" + "reflect" + "strings" + "upper.io/db" + "upper.io/db/util/sqlgen" +) + +var ( + sqlPlaceholder = sqlgen.RawValue(`?`) + sqlNull = sqlgen.RawValue(`NULL`) +) + +type TemplateWithUtils struct { + *sqlgen.Template +} + +func NewTemplateWithUtils(template *sqlgen.Template) *TemplateWithUtils { + return &TemplateWithUtils{template} +} + +// ToWhereWithArguments converts the given db.Cond parameters into a sqlgen.Where +// value. +func (tu *TemplateWithUtils) ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + for i := range t { + w, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case db.And: + var op sqlgen.And + for i := range t { + k, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + op.Conditions = append(op.Conditions, k.Conditions...) + } + where.Conditions = append(where.Conditions, &op) + return + case db.Or: + var op sqlgen.Or + for i := range t { + w, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + op.Conditions = append(op.Conditions, w.Conditions...) + } + where.Conditions = append(where.Conditions, &op) + return + case db.Raw: + if s, ok := t.Value.(string); ok { + where.Conditions = append(where.Conditions, sqlgen.RawValue(s)) + } + return + case db.Cond: + cv, v := tu.ToColumnValues(t) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + return + case db.Constrainer: + cv, v := tu.ToColumnValues(t.Constraint()) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + return + } + + panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), term)) +} + +// ToInterfaceArguments converts the given value into an array of interfaces. +func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []interface{}) { + if value == nil { + return nil + } + + v := reflect.ValueOf(value) + + switch v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = v.Index(i).Interface() + } + + return args + } + return nil + default: + args = []interface{}{value} + } + + return args +} + +// ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct. +func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { + + args = []interface{}{} + + for column, value := range cond { + columnValue := sqlgen.ColumnValue{} + + // Guessing operator from input, or using a default one. + column := strings.TrimSpace(column) + chunks := strings.SplitN(column, ` `, 2) + + columnValue.Column = sqlgen.ColumnWithName(chunks[0]) + + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } else { + columnValue.Operator = tu.DefaultOperator + } + + switch value := value.(type) { + case db.Func: + v := tu.ToInterfaceArguments(value.Args) + columnValue.Operator = value.Name + + if v == nil { + // A function with no arguments. + columnValue.Value = sqlgen.RawValue(`()`) + } else { + // A function with one or more arguments. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } + + args = append(args, v...) + default: + v := tu.ToInterfaceArguments(value) + + l := len(v) + if v == nil || l == 0 { + // Nil value given. + columnValue.Value = sqlNull + } else { + if l > 1 { + // Array value given. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } else { + // Single value given. + columnValue.Value = sqlPlaceholder + } + args = append(args, v...) + } + } + + ToColumnValues.ColumnValues = append(ToColumnValues.ColumnValues, &columnValue) + } + + return ToColumnValues, args +} + +// ToColumnsValuesAndArguments maps the given columnNames and columnValues into +// sqlgen's Columns and Values, it also extracts and returns query arguments. +func (tu *TemplateWithUtils) ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { + var arguments []interface{} + + columns := new(sqlgen.Columns) + + columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames)) + for i := range columnNames { + columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i])) + } + + values := new(sqlgen.Values) + + arguments = make([]interface{}, 0, len(columnValues)) + values.Values = make([]sqlgen.Fragment, 0, len(columnValues)) + + for i := range columnValues { + switch v := columnValues[i].(type) { + case *sqlgen.Value: + // Adding value. + values.Values = append(values.Values, v) + case sqlgen.Value: + // Adding value. + values.Values = append(values.Values, &v) + default: + // Adding both value and placeholder. + values.Values = append(values.Values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + + return columns, values, arguments, nil +} diff --git a/util/sqlutil/debug.go b/util/sqlutil/debug.go index 08d8ebe9f9f03bad1ec16b428edcdf515f80e423..f82a98577de0c1eef0661aedbda9050c62e31413 100644 --- a/util/sqlutil/debug.go +++ b/util/sqlutil/debug.go @@ -24,7 +24,10 @@ package sqlutil import ( "fmt" "log" + "os" "strings" + + "upper.io/db" ) // Debug is used for printing SQL queries and arguments. @@ -59,3 +62,17 @@ func (d *Debug) Print() { log.Printf("\n\t%s\n\n", strings.Join(s, "\n\t")) } + +func IsDebugEnabled() bool { + if os.Getenv(db.EnvEnableDebug) != "" { + return true + } + return false +} + +func Log(query string, args []interface{}, err error, start int64, end int64) { + if IsDebugEnabled() { + d := Debug{query, args, err, start, end} + d.Print() + } +} diff --git a/util/sqlutil/fetch.go b/util/sqlutil/fetch.go index cd75d147eed1066e2d4e86690bddb048b8b30e76..73dd137678940bdbc893a93a68d66f041bb77527 100644 --- a/util/sqlutil/fetch.go +++ b/util/sqlutil/fetch.go @@ -22,7 +22,7 @@ package sqlutil import ( - "errors" + "encoding/json" "reflect" "github.com/jmoiron/sqlx" @@ -150,16 +150,101 @@ func fetchResult(itemT reflect.Type, rows *sqlx.Rows, columns []string) (reflect case reflect.Struct: values := make([]interface{}, len(columns)) - fields := rows.Mapper.TraversalsByName(itemT, columns) + typeMap := rows.Mapper.TypeMap(itemT) + fieldMap := typeMap.Names + wrappedValues := map[reflect.Value][]interface{}{} - if err = fieldsByTraversal(item, fields, values, true); err != nil { - return item, err + for i, k := range columns { + fi, ok := fieldMap[k] + if !ok { + values[i] = new(interface{}) + continue + } + + f := reflectx.FieldByIndexesReadOnly(item, fi.Index) + + // TODO: refactor into a nice pattern + if _, ok := fi.Options["stringarray"]; ok { + values[i] = &[]byte{} + wrappedValues[f] = []interface{}{"stringarray", values[i]} + } else if _, ok := fi.Options["int64array"]; ok { + values[i] = &[]byte{} + wrappedValues[f] = []interface{}{"int64array", values[i]} + } else if _, ok := fi.Options["jsonb"]; ok { + values[i] = &[]byte{} + wrappedValues[f] = []interface{}{"jsonb", values[i]} + } else { + values[i] = f.Addr().Interface() + } + + if u, ok := values[i].(db.Unmarshaler); ok { + values[i] = scanner{u} + } } + // Scanner - for reads + // Valuer - for writes + + // OptionTypes + // - before/after scan + // - before/after valuer.. + if err = rows.Scan(values...); err != nil { return item, err } + // TODO: move this stuff out of here.. find a nice pattern + for f, v := range wrappedValues { + opt := v[0].(string) + b := v[1].(*[]byte) + + switch opt { + case "stringarray": + v := StringArray{} + err := v.Scan(*b) + if err != nil { + return item, err + } + f.Set(reflect.ValueOf(v)) + case "int64array": + v := Int64Array{} + err := v.Scan(*b) + if err != nil { + return item, err + } + f.Set(reflect.ValueOf(v)) + case "jsonb": + if len(*b) == 0 { + continue + } + + var vv reflect.Value + t := reflect.PtrTo(f.Type()) + + switch t.Kind() { + case reflect.Map: + vv = reflect.MakeMap(t) + case reflect.Slice: + vv = reflect.MakeSlice(t, 0, 0) + default: + vv = reflect.New(t) + } + + err := json.Unmarshal(*b, vv.Interface()) + if err != nil { + return item, err + } + + vv = vv.Elem().Elem() + + if !vv.IsValid() || (vv.Kind() == reflect.Ptr && vv.IsNil()) { + continue + } + + f.Set(vv) + } + } + case reflect.Map: columns, err := rows.Columns() @@ -188,34 +273,3 @@ func fetchResult(itemT reflect.Type, rows *sqlx.Rows, columns []string) (reflect return item, nil } - -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { - v = reflect.Indirect(v) - - if v.Kind() != reflect.Struct { - return errors.New("argument not a struct") - } - - for i, traversal := range traversals { - - if len(traversal) == 0 { - values[i] = new(interface{}) - continue - } - - f := reflectx.FieldByIndexes(v, traversal) - - if ptrs { - values[i] = f.Addr().Interface() - } else { - values[i] = f.Interface() - } - - // Provides compatibility with db.Unmarshaler - if u, ok := values[i].(db.Unmarshaler); ok { - values[i] = scanner{u} - } - - } - return nil -} diff --git a/sqlite/result.go b/util/sqlutil/result/result.go similarity index 58% rename from sqlite/result.go rename to util/sqlutil/result/result.go index 6504200c81894b9a1659678fc25db5b969c9b3f7..7badb0e750e2774ff0f511bb1018cccd8b624813 100644 --- a/sqlite/result.go +++ b/util/sqlutil/result/result.go @@ -19,7 +19,7 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package sqlite +package result import ( "fmt" @@ -31,12 +31,16 @@ import ( "upper.io/db/util/sqlutil" ) +var ( + sqlPlaceholder = sqlgen.RawValue(`?`) +) + type counter struct { Total uint64 `db:"_t"` } -type result struct { - table *table +type Result struct { + table DataProvider cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. limit sqlgen.Limit offset sqlgen.Offset @@ -45,63 +49,73 @@ type result struct { orderBy sqlgen.OrderBy groupBy sqlgen.GroupBy arguments []interface{} + template *sqlutil.TemplateWithUtils +} + +// NewResult creates and results a new result set on the given table, this set +// is limited by the given sqlgen.Where conditions. +func NewResult(template *sqlutil.TemplateWithUtils, p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { + return &Result{ + table: p, + where: where, + arguments: arguments, + template: template, + } } // Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { +func (r *Result) setCursor() error { var err error // We need a cursor, if the cursor does not exists yet then we create one. if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, + r.cursor, err = r.table.Query(sqlgen.Statement{ + Type: sqlgen.Select, + Table: sqlgen.TableWithName(r.table.Name()), + Columns: &r.columns, Limit: r.limit, Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, + Where: &r.where, + OrderBy: &r.orderBy, + GroupBy: &r.groupBy, }, r.arguments...) } return err } // Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) +func (r *Result) Where(terms ...interface{}) db.Result { + r.where, r.arguments = r.template.ToWhereWithArguments(terms) return r } // Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { +func (r *Result) Limit(n uint) db.Result { r.limit = sqlgen.Limit(n) return r } // Determines how many documents will be skipped before starting to grab // results. -func (r *result) Skip(n uint) db.Result { +func (r *Result) Skip(n uint) db.Result { r.offset = sqlgen.Offset(n) return r } // Used to group results that have the same value in the same column or // columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) +func (r *Result) Group(fields ...interface{}) db.Result { + var columns []sqlgen.Fragment - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) + for i := range fields { + switch v := fields[i].(type) { + case string: + columns = append(columns, sqlgen.ColumnWithName(v)) + case sqlgen.Fragment: + columns = append(columns, v) } } - r.groupBy = groupByColumns + r.groupBy = *sqlgen.GroupByColumns(columns...) return r } @@ -109,54 +123,52 @@ func (r *result) Group(fields ...interface{}) db.Result { // Determines sorting of results according to the provided names. Fields may be // prefixed by - (minus) which means descending order, ascending order would be // used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { +func (r *Result) Sort(fields ...interface{}) db.Result { - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) + var sortColumns sqlgen.SortColumns - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn + for i := range fields { + var sort *sqlgen.SortColumn switch value := fields[i].(type) { case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)), + Order: sqlgen.Ascendent, } case string: if strings.HasPrefix(value, `-`) { // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value[1:]), + Order: sqlgen.Descendent, } } else { // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value), + Order: sqlgen.Ascendent, } } } - sortColumns = append(sortColumns, sort) + sortColumns.Columns = append(sortColumns.Columns, sort) } - r.orderBy.SortColumns = sortColumns + r.orderBy.SortColumns = &sortColumns return r } // Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { +func (r *Result) Select(fields ...interface{}) db.Result { - r.columns = make(sqlgen.Columns, 0, len(fields)) + r.columns = sqlgen.Columns{} - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column + for i := range fields { + var col sqlgen.Fragment switch value := fields[i].(type) { case db.Func: - v := interfaceArgs(value.Args) + v := r.template.ToInterfaceArguments(value.Args) var s string if len(v) == 0 { s = fmt.Sprintf(`%s()`, value.Name) @@ -167,20 +179,20 @@ func (r *result) Select(fields ...interface{}) db.Result { } s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) } - col = sqlgen.Column{sqlgen.Raw{s}} + col = sqlgen.RawValue(s) case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} + col = sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)) default: - col = sqlgen.Column{value} + col = sqlgen.ColumnWithName(fmt.Sprintf(`%v`, value)) } - r.columns = append(r.columns, col) + r.columns.Columns = append(r.columns.Columns, col) } return r } // Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { +func (r *Result) All(dst interface{}) error { var err error if r.cursor != nil { @@ -203,7 +215,7 @@ func (r *result) All(dst interface{}) error { } // Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { +func (r *Result) One(dst interface{}) error { var err error if r.cursor != nil { @@ -218,7 +230,7 @@ func (r *result) One(dst interface{}) error { } // Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { +func (r *Result) Next(dst interface{}) (err error) { if err = r.setCursor(); err != nil { r.Close() @@ -234,13 +246,13 @@ func (r *result) Next(dst interface{}) (err error) { } // Removes the matching items from the collection. -func (r *result) Remove() error { +func (r *Result) Remove() error { var err error - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, + _, err = r.table.Exec(sqlgen.Statement{ + Type: sqlgen.Delete, + Table: sqlgen.TableWithName(r.table.Name()), + Where: &r.where, }, r.arguments...) return err @@ -249,35 +261,33 @@ func (r *result) Remove() error { // Updates matching items from the collection with values of the given map or // struct. -func (r *result) Update(values interface{}) error { +func (r *Result) Update(values interface{}) error { ff, vv, err := r.table.FieldValues(values) if err != nil { return err } - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) + cvs := new(sqlgen.ColumnValues) - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) + for i := range ff { + cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: r.template.AssignmentOperator, Value: sqlPlaceholder}) } vv = append(vv, r.arguments...) - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, + _, err = r.table.Exec(sqlgen.Statement{ + Type: sqlgen.Update, + Table: sqlgen.TableWithName(r.table.Name()), ColumnValues: cvs, - Where: r.where, + Where: &r.where, }, vv...) return err } // Closes the result set. -func (r *result) Close() (err error) { +func (r *Result) Close() (err error) { if r.cursor != nil { err = r.cursor.Close() r.cursor = nil @@ -286,13 +296,13 @@ func (r *result) Close() (err error) { } // Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { +func (r *Result) Count() (uint64, error) { var count counter - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, + row, err := r.table.QueryRow(sqlgen.Statement{ + Type: sqlgen.Count, + Table: sqlgen.TableWithName(r.table.Name()), + Where: &r.where, }, r.arguments...) if err != nil { diff --git a/util/sqlutil/result/table.go b/util/sqlutil/result/table.go new file mode 100644 index 0000000000000000000000000000000000000000..57ec42ac1bdf8abe90ea0667e7b8177089820c8d --- /dev/null +++ b/util/sqlutil/result/table.go @@ -0,0 +1,15 @@ +package result + +import ( + "database/sql" + "github.com/jmoiron/sqlx" + "upper.io/db/util/sqlgen" +) + +type DataProvider interface { + Name() string + Query(sqlgen.Statement, ...interface{}) (*sqlx.Rows, error) + QueryRow(sqlgen.Statement, ...interface{}) (*sqlx.Row, error) + Exec(sqlgen.Statement, ...interface{}) (sql.Result, error) + FieldValues(interface{}) ([]string, []interface{}, error) +} diff --git a/util/sqlutil/scanner.go b/util/sqlutil/scanner.go index 5baa5ca8d3b695171fd9ee10160b0c6e16f0c4ff..34bde7a0fe4fc2b2b8db6d9c6f32c9bda48cd132 100644 --- a/util/sqlutil/scanner.go +++ b/util/sqlutil/scanner.go @@ -23,6 +23,12 @@ package sqlutil import ( "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "strconv" + "strings" + "upper.io/db" ) @@ -35,3 +41,148 @@ func (u scanner) Scan(v interface{}) error { } var _ sql.Scanner = scanner{} + +//------ + +type JsonbType struct { + V interface{} +} + +func (j *JsonbType) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + + v := JsonbType{} + if err := json.Unmarshal(b, &v.V); err != nil { + return err + } + *j = v + return nil +} + +func (j JsonbType) Value() (driver.Value, error) { + b, err := json.Marshal(j.V) + if err != nil { + return nil, err + } + return b, nil +} + +//------ + +type StringArray []string + +func (a *StringArray) Scan(src interface{}) error { + if src == nil { + *a = StringArray{} + return nil + } + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + if len(b) == 0 { + return nil + } + s := string(b)[1 : len(b)-1] + if s == "" { + return nil + } + results := strings.Split(s, ",") + *a = StringArray(results) + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedString(b, a[0]) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedString(b, a[i]) + } + + return append(b, '}'), nil + } + + return []byte{'{', '}'}, nil +} + +func appendArrayQuotedString(b []byte, v string) []byte { + b = append(b, '"') + for { + i := strings.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +//------ + +type Int64Array []int64 + +func (a *Int64Array) Scan(src interface{}) error { + if src == nil { + return nil + } + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + + s := string(b)[1 : len(b)-1] + parts := strings.Split(s, ",") + results := make([]int64, 0) + for _, n := range parts { + i, err := strconv.ParseInt(n, 10, 64) + if err != nil { + return err + } + results = append(results, i) + } + *a = Int64Array(results) + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return append(b, '}'), nil + } + + return []byte{'{', '}'}, nil +} diff --git a/util/sqlutil/main.go b/util/sqlutil/sqlutil.go similarity index 63% rename from util/sqlutil/main.go rename to util/sqlutil/sqlutil.go index 77cc0637a4e5b25fdcca9bbd8327eedbd5a3bde8..4d2b54df5b3074dafa2e02561e149061c340be1e 100644 --- a/util/sqlutil/main.go +++ b/util/sqlutil/sqlutil.go @@ -23,20 +23,18 @@ package sqlutil import ( "database/sql" + "fmt" "reflect" "regexp" "strings" "github.com/jmoiron/sqlx/reflectx" - - "menteslibres.net/gosexy/to" - "upper.io/db" - "upper.io/db/util" ) var ( - reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) + reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`) ) var ( @@ -50,29 +48,19 @@ var ( // using FieldValues() type T struct { Columns []string + Mapper *reflectx.Mapper + Tables []string // Holds table names. } func (t *T) columnLike(s string) string { for _, name := range t.Columns { - if util.NormalizeColumn(s) == util.NormalizeColumn(name) { + if normalizeColumn(s) == normalizeColumn(name) { return name } } return s } -func marshal(v interface{}) (interface{}, error) { - - if m, isMarshaler := v.(db.Marshaler); isMarshaler { - var err error - if v, err = m.MarshalDB(); err != nil { - return nil, err - } - } - - return v, nil -} - func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { fields := []string{} values := []interface{}{} @@ -90,78 +78,46 @@ func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { switch itemT.Kind() { case reflect.Struct: - nfields := itemV.NumField() + + fieldMap := t.Mapper.TypeMap(itemT).Names + nfields := len(fieldMap) values = make([]interface{}, 0, nfields) fields = make([]string, 0, nfields) - for i := 0; i < nfields; i++ { - - field := itemT.Field(i) + for _, fi := range fieldMap { + // log.Println("=>", fi.Name, fi.Options) - if field.PkgPath != `` { - // Field is unexported. + fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) + if fld.Kind() == reflect.Ptr && fld.IsNil() { continue } - // TODO: can we get the placeholder used above somewhere...? - // from the sqlx part..? - - if field.Anonymous { - // It's an anonymous field. Let's skip it unless it has an explicit - // `db` tag. - if field.Tag.Get(`db`) == `` { - continue - } - } - - // Field options. - fieldName, fieldOptions := util.ParseTag(field.Tag.Get(`db`)) - - // Skipping field - if fieldName == `-` { - continue - } - - // Trying to match field name. - - // Still don't have a match? try to match againt JSON. - if fieldName == `` { - fieldName, _ = util.ParseTag(field.Tag.Get(`json`)) - } - - // Nothing works, trying to match by name. - if fieldName == `` { - fieldName = t.columnLike(field.Name) + var value interface{} + if _, ok := fi.Options["stringarray"]; ok { + value = StringArray(fld.Interface().([]string)) + } else if _, ok := fi.Options["int64array"]; ok { + value = Int64Array(fld.Interface().([]int64)) + } else if _, ok := fi.Options["jsonb"]; ok { + value = JsonbType{fld.Interface()} + } else { + value = fld.Interface() } - // Processing tag options. - value := itemV.Field(i).Interface() - - if fieldOptions[`omitempty`] == true { - zero := reflect.Zero(reflect.TypeOf(value)).Interface() - if value == zero { + if _, ok := fi.Options["omitempty"]; ok { + if value == fi.Zero.Interface() { continue } } - if fieldOptions[`inline`] == true { - infields, invalues, inerr := t.FieldValues(value) - if inerr != nil { - return nil, nil, inerr - } - fields = append(fields, infields...) - values = append(values, invalues...) - } else { - fields = append(fields, fieldName) - v, err := marshal(value) - - if err != nil { - return nil, nil, err - } + // TODO: columnLike stuff...? - values = append(values, v) + fields = append(fields, fi.Name) + v, err := marshal(value) + if err != nil { + return nil, nil, err } + values = append(values, v) } case reflect.Map: @@ -172,7 +128,7 @@ func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { for i, keyV := range mkeys { valv := itemV.MapIndex(keyV) - fields[i] = t.columnLike(to.String(keyV.Interface())) + fields[i] = t.columnLike(fmt.Sprintf("%v", keyV.Interface())) v, err := marshal(valv.Interface()) if err != nil { @@ -181,6 +137,7 @@ func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { values[i] = v } + default: return nil, nil, db.ErrExpectingMapOrStruct } @@ -188,6 +145,16 @@ func (t *T) FieldValues(item interface{}) ([]string, []interface{}, error) { return fields, values, nil } +func marshal(v interface{}) (interface{}, error) { + if m, isMarshaler := v.(db.Marshaler); isMarshaler { + var err error + if v, err = m.MarshalDB(); err != nil { + return nil, err + } + } + return v, nil +} + func reset(data interface{}) error { // Resetting element. v := reflect.ValueOf(data).Elem() @@ -197,16 +164,28 @@ func reset(data interface{}) error { return nil } +// normalizeColumn prepares a column for comparison against another column. +func normalizeColumn(s string) string { + return strings.ToLower(reColumnCompareExclude.ReplaceAllString(s, "")) +} + // NewMapper creates a reflectx.Mapper func NewMapper() *reflectx.Mapper { - mapFunc := strings.ToLower + return reflectx.NewMapper("db") +} - tagFunc := func(value string) string { - if strings.Contains(value, ",") { - return strings.Split(value, ",")[0] +// MainTableName returns the name of the first table. +func (t *T) MainTableName() string { + return t.NthTableName(0) +} + +// NthTableName returns the table name at index i. +func (t *T) NthTableName(i int) string { + if len(t.Tables) > i { + chunks := strings.SplitN(t.Tables[i], " ", 2) + if len(chunks) > 0 { + return chunks[0] } - return value } - - return reflectx.NewMapperTagFunc("db", mapFunc, tagFunc) + return "" } diff --git a/postgresql/tx.go b/util/sqlutil/tx/tx.go similarity index 84% rename from postgresql/tx.go rename to util/sqlutil/tx/tx.go index 70b979d74e3763a57941e60e13e261083236778a..533c54058f9f855dbc3004e8fad03f5f3022e0bb 100644 --- a/postgresql/tx.go +++ b/util/sqlutil/tx/tx.go @@ -19,26 +19,28 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package postgresql +package sqltx import ( "github.com/jmoiron/sqlx" ) -type tx struct { - *source - sqlTx *sqlx.Tx - done bool +type Tx struct { + *sqlx.Tx + done bool } -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { +func New(tx *sqlx.Tx) *Tx { + return &Tx{Tx: tx} +} + +func (t *Tx) Done() bool { + return t.done +} + +func (t *Tx) Commit() (err error) { + if err = t.Tx.Commit(); err == nil { t.done = true } return err } - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -}