diff --git a/.travis.yml b/.travis.yml index c363c5cced3b63e91fe6fb6e8a6db4baa23289b6..e9af1942e1c2a687c8e2878f2eec0a61f6d9e6fa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,15 +13,16 @@ go: - 1.4.1 - 1.4.2 -env: GOARCH=amd64 TEST_HOST=127.0.0.1 +env: GOARCH=amd64 TEST_HOST=127.0.0.1 UPPERIO_DB_DEBUG=1 install: - - sudo apt-get install bzr - # - go get github.com/cznic/ql/ql # ql command line util. - # - go install github.com/cznic/ql/ql # ql command line util. + - sudo apt-get install -y bzr make + - go get github.com/cznic/ql/ql # ql command line util. + - go install github.com/cznic/ql/ql # ql command line util. - mkdir ../../../upper.io - ln -s $PWD ../../../upper.io/db - go get -t -d + - cd $GOPATH/src/github.com/jmoiron/sqlx && git checkout ptrs # Peter's branch. # - go get upper.io/db/mongo # - go get upper.io/db/mysql # - go get upper.io/db/postgresql @@ -35,11 +36,24 @@ before_script: - mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root mysql - cat mysql/_dumps/setup.sql | mysql -uroot - cat mysql/_dumps/structs.sql | mysql -uupperio -pupperio upperio_tests + - cat postgresql/_dumps/setup.sql | psql -U postgres - cat postgresql/_dumps/structs.sql | PGPASSWORD="upperio" psql -U upperio upperio_tests + - mongo upperio_tests --eval 'db.addUser("upperio", "upperio")' + + - (cd mysql/_dumps && make) + - (cd postgresql/_dumps && make) + - (cd sqlite/_dumps && make) + - (cd ql/_dumps && make) + # - cat ql/_dumps/structs.sql | $GOPATH/bin/ql -db ql/_dumps/test.db script: - go test upper.io/db/util/sqlgen -test.bench=. - - UPPERIO_DB_DEBUG=1 go test -test.v=1 + - go test upper.io/db/postgresql -test.bench=. + - go test upper.io/db/mysql -test.bench=. + - go test upper.io/db/sqlite -test.bench=. + - go test upper.io/db/ql -test.bench=. + - go test upper.io/db/mongo -test.bench=. + - go test -test.v diff --git a/mongo/database_test.go b/mongo/database_test.go index 1cd42677cd74b063d9b1a740e873031027a00bf1..fe426bb7f2a3edcd4bde08a68e905968dde222ee 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -24,7 +24,6 @@ package mongo import ( "errors" - "flag" "math/rand" "os" "reflect" @@ -52,7 +51,7 @@ var settings = ConnectionURL{ Password: password, } -var host = flag.String("host", "testserver.local", "Testing server address.") +var host string // Structure for testing conversions and datatypes. type testValuesStruct struct { @@ -127,8 +126,11 @@ func init() { time.Second * time.Duration(7331), } - flag.Parse() - settings.Address = db.ParseAddress(*host) + if host = os.Getenv("TEST_HOST"); host == "" { + host = "localhost" + } + + settings.Address = db.ParseAddress(host) } // Enabling outputting some information to stdout, useful for development. @@ -155,7 +157,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: username, Password: password, } @@ -169,7 +171,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: username, Password: "fail", } @@ -181,7 +183,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong database. wrongSettings = db.Settings{ Database: "fail", - Host: *host, + Host: host, User: username, Password: password, } @@ -193,7 +195,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ Database: database, - Host: *host, + Host: host, User: "fail", Password: password, } diff --git a/mysql/_dumps/Makefile b/mysql/_dumps/Makefile index 5e9b4f17c51413eec7ced1a74cea64aa3e70e253..fb1b4e28247a353505402c796fec0018420ad638 100644 --- a/mysql/_dumps/Makefile +++ b/mysql/_dumps/Makefile @@ -1,2 +1,4 @@ +TEST_HOST ?= 127.0.0.1 + all: - cat structs.sql | mysql -uupperio -pupperio upperio_tests -htestserver.local + cat structs.sql | mysql -uupperio -pupperio upperio_tests -h$(TEST_HOST) diff --git a/mysql/collection.go b/mysql/collection.go index fb1ee860a3665618821861c3c362a3f4ed001c90..d510642b72976d0a02435c5469ec4e54e08e6caf 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -23,233 +23,57 @@ package mysql import ( "database/sql" - "fmt" - "reflect" "strings" "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source - names []string -} - -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args + *database } -func interfaceArgs(value interface{}) (args []interface{}) { - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } +var _ = db.Collection(&table{}) - return args +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} - -func (c *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: c, - where: where, - arguments: arguments, - } - - return result -} - -func (c *table) tableN(i int) string { - if len(c.names) > i { - chunks := strings.SplitN(c.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. -func (c *table) Truncate() error { - - _, err := c.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{c.tableN(0)}, +// Truncate deletes all rows from the table. +func (t *table) Truncate() error { + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. -func (c *table) Append(item interface{}) (interface{}, error) { - +// Append inserts an item (map or struct) into the collection. +func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - cols, vals, err := c.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = c.source.getPrimaryKey(c.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -257,14 +81,14 @@ func (c *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{c.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = c.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -291,10 +115,10 @@ func (c *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -317,14 +141,15 @@ func (c *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. -func (c *table) Exists() bool { - if err := c.source.tableExists(c.names...); err != nil { +// Exists returns true if the collection exists. +func (t *table) Exists() bool { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } -func (c *table) Name() string { - return strings.Join(c.names, `, `) +// Name returns the name of the table or tables that form the collection. +func (t *table) Name() string { + return strings.Join(t.Tables, `, `) } diff --git a/mysql/database.go b/mysql/database.go index 97bf5038620bd07e153b01299785bc6017a753e4..df840b2221f2bc1dbf52cf8e0fec1841d3fc8b48 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -23,95 +23,55 @@ package mysql import ( "database/sql" - "os" "strings" "time" - // Importing MySQL driver. - _ "github.com/go-sql-driver/mysql" + _ "github.com/go-sql-driver/mysql" // MySQL driver. "github.com/jmoiron/sqlx" - "upper.io/cache" "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/tx" ) -const ( - // Adapter is the public name of the adapter. - Adapter = `mysql` -) - -var template *sqlgen.Template - var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema } -type columnSchemaT struct { - Name string `db:"column_name"` -} - -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name +type tx struct { + *sqltx.Tx + *database } -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} - -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sqlx.Tx - - if sqlTx, err = s.session.Beginx(); err != nil { - return nil, err - } - - if clone, err = s.clone(); err != nil { - return nil, err - } - - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -// Stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { - s.connURL = connURL - return s.Open() +type columnSchemaT struct { + Name string `db:"column_name"` } -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { var err error // Before db.ConnectionURL we used a unified db.Settings struct. This // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { + if settings, ok := d.connURL.(db.Settings); ok { // User is providing a db.Settings struct, let's translate it into a // ConnectionURL{}. @@ -141,64 +101,71 @@ func (s *source) Open() error { conn.Address = db.HostPort(settings.Host, uint(settings.Port)) } - // Replace original s.connURL - s.connURL = conn + // Replace original d.connURL + d.connURL = conn } - if s.session, err = sqlx.Open(`mysql`, s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`mysql`, d.connURL.String()); err != nil { return err } - s.session.Mapper = sqlutil.NewMapper() + d.session.Mapper = sqlutil.NewMapper() - if err = s.populateSchema(); err != nil { + if err = d.populateSchema(); err != nil { return err } return nil } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL +func (d *database) clone() (*database, error) { + src := &database{} + src.Setup(d.connURL) - if conn, err = ParseURL(s.connURL.String()); err != nil { - return err + if err := src.Open(); err != nil { + return nil, err } - conn.Database = database + return src, nil +} - s.connURL = conn +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() +} - return s.Open() +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { var err error if len(names) == 0 { return nil, db.ErrMissingCollectionName } - if s.tx != nil { - if s.tx.done { + if d.tx != nil { + if d.tx.Done() { return nil, sql.ErrTxDone } } - col := &table{ - source: s, - names: names, - } + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper for _, name := range names { chunks := strings.SplitN(name, ` `, 2) @@ -209,11 +176,11 @@ func (s *source) Collection(names ...string) (db.Collection, error) { tableName := chunks[0] - if err := s.tableExists(tableName); err != nil { + if err := d.tableExists(tableName); err != nil { return nil, err } - if col.Columns, err = s.tableColumns(tableName); err != nil { + if col.Columns, err = d.tableColumns(tableName); err != nil { return nil, err } } @@ -221,49 +188,35 @@ func (s *source) Collection(names ...string) (db.Collection, error) { return col, nil } -// Drops the currently active database. -func (s *source) Drop() error { - - _, err := s.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlDropDatabase, - Database: sqlgen.Database{s.schema.Name}, - }) - - return err -} - -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - tablesInSchema := len(s.schema.Tables) + tablesInSchema := len(d.schema.Tables) // Is schema already populated? if tablesInSchema > 0 { // Pulling table names from schema. - return s.schema.Tables, nil + return d.schema.Tables, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Table: sqlgen.Table{ - `information_schema.tables`, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`table_schema`}, - `=`, - sqlPlaceholder, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Table: sqlgen.TableWithName(`information_schema.tables`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, }, - }, + ), } // Executing statement. var rows *sqlx.Rows - if rows, err = s.doQuery(stmt, s.schema.Name); err != nil { + if rows, err = d.Query(stmt, d.schema.Name); err != nil { return nil, err } @@ -280,7 +233,7 @@ func (s *source) Collections() (collections []string, err error) { } // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.AddTable(name) // Adding table to collections array. collections = append(collections, name) @@ -289,25 +242,65 @@ func (s *source) Collections() (collections []string, err error) { return collections, nil } -func (s *source) clone() (*source, error) { - src := &source{} - src.Setup(s.connURL) +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL - if err := src.Open(); err != nil { - return nil, err + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err } - return src, nil + conn.Database = database + + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) + + return err +} + +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err + } + + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err } + + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -317,25 +310,26 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - res, err = s.session.Exec(query, args...) + res, err = d.session.Exec(query, args...) } return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -345,25 +339,26 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - rows, err = s.session.Queryx(query, args...) + rows, err = d.session.Queryx(query, args...) } return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -373,62 +368,57 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - row = s.session.QueryRowx(query, args...) + row = d.session.QueryRowx(query, args...) } return row, err } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} - -func (s *source) populateSchema() (err error) { +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { var collections []string - s.schema = schema.NewDatabaseSchema() + d.schema = schema.NewDatabaseSchema() // Get database name. stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {sqlgen.Raw{`DATABASE()`}}, - }, + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.RawValue(`DATABASE()`), + ), } var row *sqlx.Row - if row, err = s.doQueryRow(stmt); err != nil { + if row, err = d.QueryRow(stmt); err != nil { return err } - if err = row.Scan(&s.schema.Name); err != nil { + if err = row.Scan(&d.schema.Name); err != nil { return err } // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if collections, err = d.Collections(); err != nil { return err } for i := range collections { // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { + if _, err = d.Collection(collections[i]); err != nil { return err } } @@ -436,31 +426,39 @@ func (s *source) populateSchema() (err error) { return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.tables`}, - Columns: sqlgen.Columns{ - {`table_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.tables`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`table_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -474,32 +472,40 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`information_schema.columns`}, - Columns: sqlgen.Columns{ - {`column_name`}, - {`data_type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`table_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`information_schema.columns`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`column_name`), + sqlgen.ColumnWithName(`data_type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`table_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -509,54 +515,64 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -func (s *source) getPrimaryKey(tableName string) ([]string, error) { +func (d *database) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { return tableSchema.PrimaryKey, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{ - sqlgen.Raw{` + Type: sqlgen.Select, + Table: sqlgen.RawValue(` information_schema.table_constraints AS t JOIN information_schema.key_column_usage k USING(constraint_name, table_schema, table_name) - `}, - }, - Columns: sqlgen.Columns{ - {`k.column_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`t.constraint_type`}, `=`, sqlgen.Value{`primary key`}}, - sqlgen.ColumnValue{sqlgen.Column{`t.table_schema`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`t.table_name`}, `=`, sqlPlaceholder}, - }, - OrderBy: sqlgen.OrderBy{ - sqlgen.SortColumns{ - { - sqlgen.Column{`k.ordinal_position`}, - sqlgen.SqlSortAsc, - }, + `), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`k.column_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.constraint_type`), + Operator: `=`, + Value: sqlgen.NewValue(`primary key`), + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.table_schema`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`t.table_name`), + Operator: `=`, + Value: sqlPlaceholder, }, + ), + OrderBy: &sqlgen.OrderBy{ + SortColumns: sqlgen.JoinSortColumns( + &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(`k.ordinal_position`), + Order: sqlgen.Ascendent, + }, + ), }, } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -572,40 +588,3 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { return tableSchema.PrimaryKey, nil } - -func init() { - - template = &sqlgen.Template{ - mysqlColumnSeparator, - mysqlIdentifierSeparator, - mysqlIdentifierQuote, - mysqlValueSeparator, - mysqlValueQuote, - mysqlAndKeyword, - mysqlOrKeyword, - mysqlNotKeyword, - mysqlDescKeyword, - mysqlAscKeyword, - mysqlDefaultOperator, - mysqlClauseGroup, - mysqlClauseOperator, - mysqlColumnValue, - mysqlTableAliasLayout, - mysqlColumnAliasLayout, - mysqlSortByColumnLayout, - mysqlWhereLayout, - mysqlOrderByLayout, - mysqlInsertLayout, - mysqlSelectLayout, - mysqlUpdateLayout, - mysqlDeleteLayout, - mysqlTruncateLayout, - mysqlDropDatabaseLayout, - mysqlDropTableLayout, - mysqlSelectCountLayout, - mysqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) -} diff --git a/mysql/database_test.go b/mysql/database_test.go index bda28252950b806086046028666c945384924c63..624dcc20f1b7db9c27e54763156153d1eec5bf82 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -40,9 +40,9 @@ import ( ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" ) const ( @@ -50,7 +50,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, Options: map[string]string{ @@ -179,7 +179,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: password, @@ -193,7 +193,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: "fail", @@ -217,7 +217,7 @@ func TestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: "fail", Password: password, @@ -234,7 +234,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, Host: host, @@ -1470,7 +1470,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec("TRUNCATE TABLE `artist`"); err != nil { b.Fatal(err) @@ -1524,7 +1524,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/mysql/mysql.go b/mysql/mysql.go new file mode 100644 index 0000000000000000000000000000000000000000..6cdf2d36b7d8aa56b163581a11b4c10f05d41ec8 --- /dev/null +++ b/mysql/mysql.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package mysql + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `mysql` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/mysql/result.go b/mysql/result.go deleted file mode 100644 index b4fe80c81bece49e71c06c8dea3872f8d2b0dbb9..0000000000000000000000000000000000000000 --- a/mysql/result.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package mysql - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"_t"` -} - -type result struct { - table *table - cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - if err = sqlutil.FetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return nil -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values) - if err != nil { - return err - } - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() (err error) { - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { - var count counter - - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - err = row.Scan(&count.Total) - if err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/mysql/layout.go b/mysql/template.go similarity index 63% rename from mysql/layout.go rename to mysql/template.go index c0ef34060f173ded0bbe07bc95f3c03611f6cf8f..1b2d21dab87f67e0f0cc87ed872d9ba9f37b1112 100644 --- a/mysql/layout.go +++ b/mysql/template.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,37 +22,38 @@ package mysql const ( - mysqlColumnSeparator = `.` - mysqlIdentifierSeparator = `, ` - mysqlIdentifierQuote = "`{{.Raw}}`" - mysqlValueSeparator = `, ` - mysqlValueQuote = `'{{.}}'` - mysqlAndKeyword = `AND` - mysqlOrKeyword = `OR` - mysqlNotKeyword = `NOT` - mysqlDescKeyword = `DESC` - mysqlAscKeyword = `ASC` - mysqlDefaultOperator = `=` - mysqlClauseGroup = `({{.}})` - mysqlClauseOperator = ` {{.}} ` - mysqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - mysqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - mysqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - mysqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - mysqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = "`{{.Value}}`" + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - mysqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - mysqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -79,19 +80,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - mysqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - mysqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - mysqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -106,7 +107,7 @@ const ( {{end}} ` - mysqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -114,23 +115,21 @@ const ( {{.Extra}} ` - mysqlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} ` - mysqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - mysqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - mysqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - mysqlNull = `NULL` ) diff --git a/mysql/tx.go b/mysql/tx.go deleted file mode 100644 index 844f11d019bf7365d5689135a0621802bdddabea..0000000000000000000000000000000000000000 --- a/mysql/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package mysql - -import ( - "github.com/jmoiron/sqlx" -) - -type tx struct { - *source - sqlTx *sqlx.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/postgresql/_dumps/Makefile b/postgresql/_dumps/Makefile index 675048196a990e22c46df5ac45a4eac2ea0c9818..83de0013912ec7f3ef76d69fd8bb4fd45f5b5a54 100644 --- a/postgresql/_dumps/Makefile +++ b/postgresql/_dumps/Makefile @@ -1,2 +1,4 @@ +TEST_HOST ?= 127.0.0.1 + all: - cat structs.sql | PGPASSWORD="upperio" psql -Uupperio upperio_tests -htestserver.local + cat structs.sql | PGPASSWORD="upperio" psql -Uupperio upperio_tests -h$(TEST_HOST) diff --git a/postgresql/collection.go b/postgresql/collection.go index cfcc4b47a5c7c804fee314ea4dc0951966088b88..0158c17dd6f8b646cc5a94e685a46dfbb4e725b0 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -24,190 +24,34 @@ package postgresql import ( "database/sql" "fmt" - "reflect" "strings" "github.com/jmoiron/sqlx" "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source + *database primaryKey string - names []string } -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - for i := range t { - w, v := whereValues(t[i]) - args = append(args, v...) - where.Conditions = append(where.Conditions, w.Conditions...) - } - return - case db.And: - var op sqlgen.And - for i := range t { - k, v := whereValues(t[i]) - args = append(args, v...) - op.Conditions = append(op.Conditions, k.Conditions...) - } - where.Conditions = append(where.Conditions, &op) - return - case db.Or: - var op sqlgen.Or - for i := range t { - w, v := whereValues(t[i]) - args = append(args, v...) - op.Conditions = append(op.Conditions, w.Conditions...) - } - where.Conditions = append(where.Conditions, &op) - return - case db.Raw: - if s, ok := t.Value.(string); ok { - where.Conditions = append(where.Conditions, sqlgen.RawValue(s)) - } - return - case db.Cond: - cv, v := columnValues(t) - args = append(args, v...) - for i := range cv.ColumnValues { - where.Conditions = append(where.Conditions, cv.ColumnValues[i]) - } - return - case db.Constrainer: - cv, v := columnValues(t.Constraint()) - args = append(args, v...) - for i := range cv.ColumnValues { - where.Conditions = append(where.Conditions, cv.ColumnValues[i]) - } - return - } - - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), term)) -} - -func interfaceArgs(value interface{}) (args []interface{}) { - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func columnValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - args = []interface{}{} - - for column, value := range cond { - columnValue := sqlgen.ColumnValue{} - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.ColumnWithName(chunks[0]) - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.RawValue(`()`) - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) - } - - args = append(args, v...) - default: - v := interfaceArgs(value) - - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.RawValue(psqlNull) - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues.ColumnValues = append(columnValues.ColumnValues, &columnValue) - } - - return columnValues, args -} +var _ = db.Collection(&table{}) +// Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: t, - where: where, - arguments: arguments, - } - - return result -} - -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -// Deletes all the rows within the collection. +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - _, err := t.source.doExec(sqlgen.Statement{ + _, err := t.database.Exec(sqlgen.Statement{ Type: sqlgen.Truncate, - Table: sqlgen.TableWithName(t.tableN(0)), + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { @@ -217,46 +61,24 @@ func (t *table) Truncate() error { return nil } -// Appends an item (map or struct) into the collection. +// Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - cols, vals, err := t.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns := new(sqlgen.Columns) - - columns.Columns = make([]sqlgen.Fragment, 0, len(cols)) - for i := range cols { - columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(cols[i])) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - values := new(sqlgen.Values) - var arguments []interface{} - - arguments = make([]interface{}, 0, len(vals)) - values.Values = make([]sqlgen.Fragment, 0, len(vals)) - - for i := range vals { - switch v := vals[i].(type) { - case *sqlgen.Value: - // Adding value. - values.Values = append(values.Values, v) - case sqlgen.Value: - // Adding value. - values.Values = append(values.Values, &v) - default: - // Adding both value and placeholder. - values.Values = append(values.Values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } var pKey []string - if pKey, err = t.source.getPrimaryKey(t.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -265,16 +87,16 @@ func (t *table) Append(item interface{}) (interface{}, error) { stmt := sqlgen.Statement{ Type: sqlgen.Insert, - Table: sqlgen.TableWithName(t.tableN(0)), - Columns: columns, - Values: values, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } // No primary keys defined. if len(pKey) == 0 { var res sql.Result - if res, err = t.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -289,7 +111,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { // A primary key was found. stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) - if rows, err = t.source.doQuery(stmt, arguments...); err != nil { + if rows, err = t.database.Query(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -336,14 +158,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. +// Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) + return strings.Join(t.Tables, `, `) } diff --git a/postgresql/connection.go b/postgresql/connection.go index b1d6ab909f1a42a678e13e50b8084ce53bbfeb3f..451383ce96af9a707eb03d9e354405127da4aeca 100644 --- a/postgresql/connection.go +++ b/postgresql/connection.go @@ -37,8 +37,8 @@ type scanner struct { i int } -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. +// Next returns the next rune. It returns 0, false if the end of the text has +// been reached. func (s *scanner) Next() (rune, bool) { if s.i >= len(s.s) { return 0, false @@ -48,8 +48,8 @@ func (s *scanner) Next() (rune, bool) { return r, true } -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. +// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the +// end of the text has been reached. func (s *scanner) SkipSpaces() (rune, bool) { r, ok := s.Next() for unicode.IsSpace(r) && ok { diff --git a/postgresql/database.go b/postgresql/database.go index 3cfe0931c27bc689de8042b8a03da084745e9530..a95c014c567687ee94704e32e9294f79f1dd93db 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -24,74 +24,58 @@ package postgresql import ( "database/sql" "fmt" - "os" "strconv" "strings" "time" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" // PostgreSQL driver. - "upper.io/cache" "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" -) - -const ( - // Adapter is the public name of the adapter. - Adapter = `postgresql` + "upper.io/db/util/sqlutil/tx" ) var ( - // Query template - template *sqlgen.Template - - // Query statement placeholder sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema } -type columnSchemaT struct { - Name string `db:"column_name"` - DataType string `db:"data_type"` +type tx struct { + *sqltx.Tx + *database } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() - } +type columnSchemaT struct { + Name string `db:"column_name"` + DataType string `db:"data_type"` } -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { var err error // Before db.ConnectionURL we used a unified db.Settings struct. This // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { + if settings, ok := d.connURL.(db.Settings); ok { - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. conn := ConnectionURL{ User: settings.User, Password: settings.Password, @@ -102,30 +86,31 @@ func (s *source) Open() error { }, } - // Replace original s.connURL - s.connURL = conn + d.connURL = conn } - if s.session, err = sqlx.Open(`postgres`, s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`postgres`, d.connURL.String()); err != nil { return err } - s.session.Mapper = sqlutil.NewMapper() + d.session.Mapper = sqlutil.NewMapper() - if err = s.populateSchema(); err != nil { + if err = d.populateSchema(); err != nil { return err } return nil } -func (s *source) Clone() (db.Database, error) { - return s.clone() +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func (s *source) clone() (*source, error) { - src := new(source) - src.Setup(s.connURL) +func (d *database) clone() (*database, error) { + src := new(database) + src.Setup(d.connURL) if err := src.Open(); err != nil { return nil, err @@ -134,39 +119,37 @@ func (s *source) clone() (*source, error) { return src, nil } -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() } return nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { var err error if len(names) == 0 { return nil, db.ErrMissingCollectionName } - if s.tx != nil { - if s.tx.done { + if d.tx != nil { + if d.tx.Done() { return nil, sql.ErrTxDone } } - col := &table{ - source: s, - names: names, - } - col.T.Mapper = s.session.Mapper + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper for _, name := range names { chunks := strings.SplitN(name, ` `, 2) @@ -177,11 +160,11 @@ func (s *source) Collection(names ...string) (db.Collection, error) { tableName := chunks[0] - if err := s.tableExists(tableName); err != nil { + if err := d.tableExists(tableName); err != nil { return nil, err } - if col.Columns, err = s.tableColumns(tableName); err != nil { + if col.Columns, err = d.tableColumns(tableName); err != nil { return nil, err } } @@ -189,16 +172,15 @@ func (s *source) Collection(names ...string) (db.Collection, error) { return col, nil } -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - tablesInSchema := len(s.schema.Tables) + tablesInSchema := len(d.schema.Tables) // Is schema already populated? if tablesInSchema > 0 { // Pulling table names from schema. - return s.schema.Tables, nil + return d.schema.Tables, nil } // Schema is empty. @@ -221,7 +203,7 @@ func (s *source) Collections() (collections []string, err error) { // Executing statement. var rows *sqlx.Rows - if rows, err = s.doQuery(stmt); err != nil { + if rows, err = d.Query(stmt); err != nil { return nil, err } @@ -238,7 +220,7 @@ func (s *source) Collections() (collections []string, err error) { } // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.AddTable(name) // Adding table to collections array. collections = append(collections, name) @@ -247,62 +229,63 @@ func (s *source) Collections() (collections []string, err error) { return collections, nil } -// Changes the active database. -func (s *source) Use(database string) (err error) { +// Use changes the active database. +func (d *database) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } conn.Database = database - s.connURL = conn + d.connURL = conn - return s.Open() + return d.Open() } -// Drops the currently active database. -func (s *source) Drop() error { - _, err := s.doQuery(sqlgen.Statement{ +// Drop removes all tables from the current database. +func (d *database) Drop() error { + _, err := d.Query(sqlgen.Statement{ Type: sqlgen.DropDatabase, - Database: sqlgen.DatabaseWithName(s.schema.Name), + Database: sqlgen.DatabaseWithName(d.schema.Name), }) return err } -// Stores database settings. -func (s *source) Setup(connURL db.ConnectionURL) error { - s.connURL = connURL - return s.Open() +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name } -func (s *source) Transaction() (db.Tx, error) { +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { var err error - var clone *source + var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = s.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = s.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } - tx := &tx{source: clone, sqlTx: sqlTx} + clone.tx = sqltx.New(sqlTx) - clone.tx = tx - - return tx, nil + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -312,30 +295,31 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { - res, err = s.session.Exec(query, args...) + res, err = d.session.Exec(query, args...) } return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -345,30 +329,31 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { - rows, err = s.session.Queryx(query, args...) + rows, err = d.session.Queryx(query, args...) } return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -378,33 +363,35 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, `$`+strconv.Itoa(i+1), 1) } - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { - row = s.session.QueryRowx(query, args...) + row = d.session.QueryRowx(query, args...) } return row, err } -func (s *source) populateSchema() (err error) { +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { var collections []string - s.schema = schema.NewDatabaseSchema() + d.schema = schema.NewDatabaseSchema() // Get database name. stmt := sqlgen.Statement{ @@ -416,22 +403,20 @@ func (s *source) populateSchema() (err error) { var row *sqlx.Row - if row, err = s.doQueryRow(stmt); err != nil { + if row, err = d.QueryRow(stmt); err != nil { return err } - if err = row.Scan(&s.schema.Name); err != nil { + if err = row.Scan(&d.schema.Name); err != nil { return err } - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if collections, err = d.Collections(); err != nil { return err } for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { + if _, err = d.Collection(collections[i]); err != nil { return err } } @@ -439,14 +424,14 @@ func (s *source) populateSchema() (err error) { return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } @@ -471,13 +456,13 @@ func (s *source) tableExists(names ...string) error { ), } - if rows, err = s.doQuery(stmt, s.schema.Name, names[i]); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, names[i]); err != nil { return db.ErrCollectionDoesNotExist } defer rows.Close() - if rows.Next() == false { + if !rows.Next() { return db.ErrCollectionDoesNotExist } } @@ -485,10 +470,10 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil @@ -518,7 +503,7 @@ func (s *source) tableColumns(tableName string) ([]string, error) { var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, s.schema.Name, tableName); err != nil { + if rows, err = d.Query(stmt, d.schema.Name, tableName); err != nil { return nil, err } @@ -530,17 +515,17 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -func (s *source) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) +func (d *database) getPrimaryKey(tableName string) ([]string, error) { + tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { return tableSchema.PrimaryKey, nil @@ -554,7 +539,7 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { sqlgen.ColumnWithName(`pg_attribute.attname`), ), Where: sqlgen.WhereConditions( - sqlgen.RawValue(`pg_class.oid = '`+tableName+`'::regclass`), + sqlgen.RawValue(`pg_class.oid = '"`+tableName+`"'::regclass`), sqlgen.RawValue(`indrelid = pg_class.oid`), sqlgen.RawValue(`pg_attribute.attrelid = pg_class.oid`), sqlgen.RawValue(`pg_attribute.attnum = ANY(pg_index.indkey)`), @@ -573,10 +558,12 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt); err != nil { + if rows, err = d.Query(stmt); err != nil { return nil, err } + defer rows.Close() + tableSchema.PrimaryKey = make([]string, 0, 1) for rows.Next() { @@ -589,39 +576,3 @@ func (s *source) getPrimaryKey(tableName string) ([]string, error) { return tableSchema.PrimaryKey, nil } - -func init() { - template = &sqlgen.Template{ - pgsqlColumnSeparator, - pgsqlIdentifierSeparator, - pgsqlIdentifierQuote, - pgsqlValueSeparator, - pgsqlValueQuote, - pgsqlAndKeyword, - pgsqlOrKeyword, - pgsqlNotKeyword, - pgsqlDescKeyword, - pgsqlAscKeyword, - pgsqlDefaultOperator, - pgsqlClauseGroup, - pgsqlClauseOperator, - pgsqlColumnValue, - pgsqlTableAliasLayout, - pgsqlColumnAliasLayout, - pgsqlSortByColumnLayout, - pgsqlWhereLayout, - pgsqlOrderByLayout, - pgsqlInsertLayout, - pgsqlSelectLayout, - pgsqlUpdateLayout, - pgsqlDeleteLayout, - pgsqlTruncateLayout, - pgsqlDropDatabaseLayout, - pgsqlDropTableLayout, - pgsqlSelectCountLayout, - pgsqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) -} diff --git a/postgresql/database_test.go b/postgresql/database_test.go index c0d522a9c57d21d8b2eee4ddd9c7d99422351346..97e41f4896162ccbd5d5e5c55734b1cb24e70d5a 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -40,9 +40,9 @@ import ( ) const ( - database = "upperio_tests" - username = "upperio" - password = "upperio" + databaseName = "upperio_tests" + username = "upperio" + password = "upperio" ) const ( @@ -50,7 +50,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, User: username, Password: password, Options: map[string]string{ @@ -181,7 +181,7 @@ func SkipTestOpenWithWrongData(t *testing.T) { // Attempt to open with safe settings. rightSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: username, Password: password, @@ -195,9 +195,9 @@ func SkipTestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong password. wrongSettings = db.Settings{ - Database: database, + Database: "fail", Host: host, - User: username, + User: "fail", Password: "fail", } @@ -219,7 +219,7 @@ func SkipTestOpenWithWrongData(t *testing.T) { // Attempt to open with wrong username. wrongSettings = db.Settings{ - Database: database, + Database: databaseName, Host: host, User: "fail", Password: password, @@ -236,7 +236,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, User: username, Password: password, Host: host, diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go new file mode 100644 index 0000000000000000000000000000000000000000..7e8363cc3b1de831ce83948f0ed1baf59fadc51b --- /dev/null +++ b/postgresql/postgresql.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package postgresql + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `postgresql` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/postgresql/layout.go b/postgresql/template.go similarity index 65% rename from postgresql/layout.go rename to postgresql/template.go index 930e6c8bc58fc643b58547d2ad0a119289607d04..aa414ecbb85e1065fa52a2145eb403ac69e6c690 100644 --- a/postgresql/layout.go +++ b/postgresql/template.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,37 +22,38 @@ package postgresql const ( - pgsqlColumnSeparator = `.` - pgsqlIdentifierSeparator = `, ` - pgsqlIdentifierQuote = `"{{.Value}}"` - pgsqlValueSeparator = `, ` - pgsqlValueQuote = `'{{.}}'` - pgsqlAndKeyword = `AND` - pgsqlOrKeyword = `OR` - pgsqlNotKeyword = `NOT` - pgsqlDescKeyword = `DESC` - pgsqlAscKeyword = `ASC` - pgsqlDefaultOperator = `=` - pgsqlClauseGroup = `({{.}})` - pgsqlClauseOperator = ` {{.}} ` - pgsqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - pgsqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - pgsqlSortByColumnLayout = `{{.Column}} {{.Order}}` - - pgsqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - pgsqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - pgsqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -79,19 +80,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - pgsqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - pgsqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - pgsqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -106,7 +107,7 @@ const ( {{end}} ` - pgsqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -114,23 +115,21 @@ const ( {{.Extra}} ` - pgsqlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} RESTART IDENTITY ` - pgsqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - pgsqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - pgsqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - psqlNull = `NULL` ) diff --git a/ql/_dumps/Makefile b/ql/_dumps/Makefile index cba56b0150415fb40b311ce885b8d59a497ff8b8..1703af5222b0476e569a542c0cc26cdefd13489d 100644 --- a/ql/_dumps/Makefile +++ b/ql/_dumps/Makefile @@ -1,3 +1,3 @@ all: rm -f test.db - cat structs.sql | ql -db test.db + cat structs.sql | $$GOPATH/bin/ql -db test.db diff --git a/ql/collection.go b/ql/collection.go index a63b7ad89ade057e6342a0e8e8c6ed591cbf00f2..53a5d4c1b7e98e0237a9a42278ab7973bda83495 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,231 +22,68 @@ package ql import ( - "fmt" + "database/sql" "reflect" "strings" "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `==` - type table struct { sqlutil.T - columnTypes map[string]reflect.Kind - source *source + *database names []string + columnTypes map[string]reflect.Kind } -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args -} +var _ = db.Collection(&table{}) +// Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: t, - where: where, - arguments: arguments, - } - - return result + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - -// Deletes all the rows within the collection. +// Truncate deletes all rows from the table. func (t *table) Truncate() error { - - _, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{t.tableN(0)}, + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. +// Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - cols, vals, err := t.FieldValues(item) - - var columns sqlgen.Columns - var values sqlgen.Values + columnNames, columnValues, err := t.FieldValues(item) - for _, col := range cols { - columns = append(columns, sqlgen.Column{col}) + if err != nil { + return nil, err } - for i := 0; i < len(vals); i++ { - values = append(values, sqlPlaceholder) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - // Error ocurred, stop appending. if err != nil { return nil, err } - res, err := t.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{t.tableN(0)}, - Columns: columns, - Values: values, - }, vals...) + stmt := sqlgen.Statement{ + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, + } - if err != nil { + var res sql.Result + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -263,14 +100,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { return id, nil } -// Returns true if the collection exists. +// Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.source.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) + return strings.Join(t.Tables, `, `) } diff --git a/ql/database.go b/ql/database.go index 1a1c5623b1c1815b373d6cb75f32ffcc6eaa6153..c6529044b60e1aca67f8371c3ae32a918d262eb0 100644 --- a/ql/database.go +++ b/ql/database.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -24,121 +24,250 @@ package ql import ( "database/sql" "fmt" - "os" "strings" "time" _ "github.com/cznic/ql/driver" // QL driver "github.com/jmoiron/sqlx" - "upper.io/cache" "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" -) - -const ( - // Adapter is the public name of the adapter. - Adapter = `ql` + "upper.io/db/util/sqlutil/tx" ) var ( - template *sqlgen.Template - - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema } +type tx struct { + *sqltx.Tx + *database +} + +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) + type columnSchemaT struct { Name string `db:"Name"` } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session +} + +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error + + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + Database: settings.Database, + } + + d.connURL = conn } - return false + + if d.session, err = sqlx.Open(`ql`, d.connURL.String()); err != nil { + return err + } + + d.session.Mapper = sqlutil.NewMapper() + + if err = d.populateSchema(); err != nil { + return err + } + + return nil } -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() +} + +func (d *database) clone() (adapter *database, err error) { + adapter = new(database) + + if err = adapter.Setup(d.connURL); err != nil { + return nil, err } + + return adapter, nil } -func init() { - - template = &sqlgen.Template{ - qlColumnSeparator, - qlIdentifierSeparator, - qlIdentifierQuote, - qlValueSeparator, - qlValueQuote, - qlAndKeyword, - qlOrKeyword, - qlNotKeyword, - qlDescKeyword, - qlAscKeyword, - qlDefaultOperator, - qlClauseGroup, - qlClauseOperator, - qlColumnValue, - qlTableAliasLayout, - qlColumnAliasLayout, - qlSortByColumnLayout, - qlWhereLayout, - qlOrderByLayout, - qlInsertLayout, - qlSelectLayout, - qlUpdateLayout, - qlDeleteLayout, - qlTruncateLayout, - qlDropDatabaseLayout, - qlDropTableLayout, - qlSelectCountLayout, - qlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() } -func (s *source) populateSchema() (err error) { - var collections []string +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() + } + return nil +} + +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error + + if len(names) == 0 { + return nil, db.ErrMissingCollectionName + } + + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } + + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper + + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) + + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } + + tableName := chunks[0] + + if err := d.tableExists(tableName); err != nil { + return nil, err + } + + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } + } - s.schema = schema.NewDatabaseSchema() + return col, nil +} + +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { + + tablesInSchema := len(d.schema.Tables) + + // Is schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil + } + + // Schema is empty. + + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Table`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + ), + } + + // Executing statement. + var rows *sqlx.Rows + if rows, err = d.Query(stmt); err != nil { + return nil, err + } + + defer rows.Close() + + collections = []string{} + + var name string + + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err + } + + // Adding table entry to schema. + d.schema.AddTable(name) + // Adding table to collections array. + collections = append(collections, name) + } + + return collections, nil +} + +// Use changes the active database. +func (d *database) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } - s.schema.Name = conn.Database + conn.Database = database - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { - return err + d.connURL = conn + + return d.Open() +} + +// Drop removes all tables from the current database. +func (d *database) Drop() error { + return db.ErrUnsupported +} + +// Setup stores database settings. +func (d *database) Setup(conn db.ConnectionURL) error { + d.connURL = conn + return d.Open() +} + +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name +} + +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { + var err error + var clone *database + var sqlTx *sqlx.Tx + + if clone, err = d.clone(); err != nil { + return nil, err } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err - } + if sqlTx, err = clone.session.Beginx(); err != nil { + return nil, err } - return err + clone.tx = sqltx.New(sqlTx) + + return tx{Tx: clone.tx, database: clone}, nil } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -148,26 +277,26 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) + if d.tx != nil { + res, err = d.tx.Exec(query, args...) } else { var tx *sqlx.Tx - if tx, err = s.session.Beginx(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -183,7 +312,8 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return res, err } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { var rows *sqlx.Rows var query string var err error @@ -193,26 +323,26 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) } else { var tx *sqlx.Tx - if tx, err = s.session.Beginx(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -228,7 +358,8 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return rows, err } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { var query string var row *sqlx.Row var err error @@ -238,26 +369,26 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() - if s.session == nil { + if d.session == nil { return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) } else { var tx *sqlx.Tx - if tx, err = s.session.Beginx(); err != nil { + if tx, err = d.session.Beginx(); err != nil { return nil, err } @@ -273,194 +404,64 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return row, err } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} - -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() -} - -func (s *source) clone() (adapter *source, err error) { - adapter = new(source) - - if err = adapter.Setup(s.connURL); err != nil { - return nil, err - } - - return adapter, nil -} - -func (s *source) Clone() (db.Database, error) { - return s.clone() -} - -func (s *source) Transaction() (db.Tx, error) { - var err error - var clone *source - var sqlTx *sqlx.Tx - - if clone, err = s.clone(); err != nil { - return nil, err - } - - if sqlTx, err = s.session.Beginx(); err != nil { - return nil, err - } - - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(conn db.ConnectionURL) error { - s.connURL = conn - return s.Open() -} - -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session -} - -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { - var err error - - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - Database: settings.Database, - } - - s.connURL = conn - } - - if s.session, err = sqlx.Open(`ql`, s.connURL.String()); err != nil { - return err - } - - s.session.Mapper = sqlutil.NewMapper() - - if err = s.populateSchema(); err != nil { - return err - } +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - return nil -} + d.schema = schema.NewDatabaseSchema() -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} - -// Changes the active database. -func (s *source) Use(database string) (err error) { var conn ConnectionURL - if conn, err = ParseURL(s.connURL.String()); err != nil { + if conn, err = ParseURL(d.connURL.String()); err != nil { return err } - conn.Database = database - - s.connURL = conn - - return s.Open() -} - -// Drops the currently active database. -func (s *source) Drop() error { - return db.ErrUnsupported -} - -// Returns a list of all tables within the currently active database. -func (s *source) Collections() (collections []string, err error) { - - tablesInSchema := len(s.schema.Tables) + d.schema.Name = conn.Database - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil - } - - // Schema is empty. - - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Table`}, - Columns: sqlgen.Columns{ - {`Name`}, - }, - } - - // Executing statement. - var rows *sqlx.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { + return err } - defer rows.Close() - - collections = []string{} - - var name string - - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err } - - // Adding table entry to schema. - s.schema.AddTable(name) - - // Adding table to collections array. - collections = append(collections, name) } - return collections, nil + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Table`}, - Columns: sqlgen.Columns{ - {`Name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`Name`}, `==`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Table`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`Name`), + Operator: `==`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, names[i]); err != nil { + if rows, err = d.Query(stmt, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -474,31 +475,35 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil } stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`__Column`}, - Columns: sqlgen.Columns{ - {`Name`}, - {`Type`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`TableName`}, `==`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`__Column`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`Name`), + sqlgen.ColumnWithName(`Type`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`TableName`), + Operator: `==`, + Value: sqlPlaceholder, + }, + ), } var rows *sqlx.Rows var err error - if rows, err = s.doQuery(stmt, tableName); err != nil { + if rows, err = d.Query(stmt, tableName); err != nil { return nil, err } @@ -508,51 +513,11 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(tableFields)) for i := range tableFields { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, tableFields[i].Name) + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, tableFields[i].Name) } - return s.schema.TableInfo[tableName].Columns, nil -} - -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } - - col := &table{ - source: s, - names: names, - } - - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) - - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName - } - - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } - - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err - } - } - - return col, nil + return d.schema.TableInfo[tableName].Columns, nil } diff --git a/ql/database_test.go b/ql/database_test.go index dd389d34df9da070abd3dde6eee1700c7fd280e6..43957d73fc4f7e13ea2fb1424f0fa1a33ff2832d 100644 --- a/ql/database_test.go +++ b/ql/database_test.go @@ -46,7 +46,7 @@ import ( ) const ( - database = `_dumps/test.db` + databaseName = `_dumps/test.db` ) const ( @@ -54,7 +54,7 @@ const ( ) var settings = db.Settings{ - Database: database, + Database: databaseName, } // Structure for testing conversions and datatypes. @@ -155,7 +155,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, } // Opening database. @@ -942,6 +942,7 @@ func TestRawQuery(t *testing.T) { var rows *sqlx.Rows var err error var drv *sqlx.DB + var tx *sqlx.Tx type publicationType struct { ID int64 `db:"id,omitempty"` @@ -957,7 +958,11 @@ func TestRawQuery(t *testing.T) { drv = sess.Driver().(*sqlx.DB) - rows, err = drv.Queryx(` + if tx, err = drv.Beginx(); err != nil { + t.Fatal(err) + } + + if rows, err = tx.Queryx(` SELECT p.id AS id, p.title AS publication_title, @@ -967,9 +972,11 @@ func TestRawQuery(t *testing.T) { (SELECT id() AS id, title, author_id FROM publication) AS p WHERE a.id == p.author_id - `) + `); err != nil { + t.Fatal(err) + } - if err != nil { + if err = tx.Commit(); err != nil { t.Fatal(err) } diff --git a/ql/ql.go b/ql/ql.go new file mode 100644 index 0000000000000000000000000000000000000000..acffe02cff797d9a84fcb0e853d938a68dd08d51 --- /dev/null +++ b/ql/ql.go @@ -0,0 +1,72 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package ql + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `ql` + +var template *sqlutil.TemplateWithUtils + +func init() { + + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/ql/result.go b/ql/result.go deleted file mode 100644 index 60985da8abcbd586424ff99aa8544b32decfdf74..0000000000000000000000000000000000000000 --- a/ql/result.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"total"` -} - -type result struct { - table *table - cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - if err = sqlutil.FetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return nil -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values) - if err != nil { - return err - } - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() (err error) { - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { - var count counter - - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - err = row.Scan(&count.Total) - if err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/ql/layout.go b/ql/template.go similarity index 65% rename from ql/layout.go rename to ql/template.go index d60f32e2d49b1bb55ba2f4dcae6dae061e31c609..61fc47c044905a107393e34539f62e3f19e1cebc 100644 --- a/ql/layout.go +++ b/ql/template.go @@ -22,37 +22,38 @@ package ql const ( - qlColumnSeparator = `.` - qlIdentifierSeparator = `, ` - qlIdentifierQuote = `{{.Raw}}` - qlValueSeparator = `, ` - qlValueQuote = `"{{.}}"` - qlAndKeyword = `&&` - qlOrKeyword = `||` - qlNotKeyword = `!=` - qlDescKeyword = `DESC` - qlAscKeyword = `ASC` - qlDefaultOperator = `==` - qlClauseGroup = `({{.}})` - qlClauseOperator = ` {{.}} ` - qlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - qlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - qlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - qlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - qlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `{{.Value}}` + adapterValueSeparator = `, ` + adapterValueQuote = `"{{.}}"` + adapterAndKeyword = `&&` + adapterOrKeyword = `||` + adapterNotKeyword = `!=` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `==` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - qlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - qlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -77,19 +78,19 @@ const ( OFFSET {{.Offset}} {{end}} ` - qlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - qlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - qlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT count(1) AS total FROM {{.Table}} @@ -104,7 +105,7 @@ const ( {{end}} ` - qlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -112,19 +113,19 @@ const ( {{.Extra}} ` - qlTruncateLayout = ` + adapterTruncateLayout = ` TRUNCATE TABLE {{.Table}} ` - qlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - qlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - qlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} diff --git a/ql/tx.go b/ql/tx.go deleted file mode 100644 index 6093f5aa25f5ae6e53a327dc30a06606f8a7e84e..0000000000000000000000000000000000000000 --- a/ql/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package ql - -import ( - "github.com/jmoiron/sqlx" -) - -type tx struct { - *source - sqlTx *sqlx.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/sqlite/collection.go b/sqlite/collection.go index c65e64d4a1790f5f4e0337c268399d8e5c9d5f2d..0aadfcd010ce8ae64d1e4ce1bc6321b3d69bd9c4 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -22,8 +22,6 @@ package sqlite import ( - "fmt" - "reflect" "strings" "database/sql" @@ -31,228 +29,52 @@ import ( "upper.io/db" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/result" ) -const defaultOperator = `=` - type table struct { sqlutil.T - source *source - names []string -} - -func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { - - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - where = make(sqlgen.Where, 0, l) - for _, cond := range t { - w, v := whereValues(cond) - args = append(args, v...) - where = append(where, w...) - } - case db.And: - and := make(sqlgen.And, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - and = append(and, k...) - } - where = append(where, and) - case db.Or: - or := make(sqlgen.Or, 0, len(t)) - for _, cond := range t { - k, v := whereValues(cond) - args = append(args, v...) - or = append(or, k...) - } - where = append(where, or) - case db.Raw: - if s, ok := t.Value.(string); ok == true { - where = append(where, sqlgen.Raw{s}) - } - case db.Cond: - k, v := conditionValues(t) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - case db.Constrainer: - k, v := conditionValues(t.Constraint()) - args = append(args, v...) - for _, kk := range k { - where = append(where, kk) - } - default: - panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), reflect.TypeOf(t))) - } - - return where, args -} - -func interfaceArgs(value interface{}) (args []interface{}) { - if value == nil { - return nil - } - - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: - var i, total int - - total = v.Len() - if total > 0 { - args = make([]interface{}, total) - - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - - return args - } - return nil - default: - args = []interface{}{value} - } - - return args -} - -func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []interface{}) { - - args = []interface{}{} - - for column, value := range cond { - var columnValue sqlgen.ColumnValue - - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) - - columnValue.Column = sqlgen.Column{chunks[0]} - - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } else { - columnValue.Operator = defaultOperator - } - - switch value := value.(type) { - case db.Func: - // Catches functions. - v := interfaceArgs(value.Args) - columnValue.Operator = value.Name - - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`()`}} - } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } - - args = append(args, v...) - default: - // Catches everything else. - v := interfaceArgs(value) - l := len(v) - if v == nil || l == 0 { - // Nil value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{`NULL`}} - } else { - if l > 1 { - // Array value given. - columnValue.Value = sqlgen.Value{sqlgen.Raw{fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))}} - } else { - // Single value given. - columnValue.Value = sqlPlaceholder - } - args = append(args, v...) - } - } - - columnValues = append(columnValues, columnValue) - } - - return columnValues, args + *database } -func (c *table) Find(terms ...interface{}) db.Result { - where, arguments := whereValues(terms) - - result := &result{ - table: c, - where: where, - arguments: arguments, - } +var _ = db.Collection(&table{}) - return result -} - -func (c *table) tableN(i int) string { - if len(c.names) > i { - chunks := strings.SplitN(c.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" +// Find creates a result set with the given conditions. +func (t *table) Find(terms ...interface{}) db.Result { + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } -// Deletes all the rows within the collection. -func (c *table) Truncate() error { - - _, err := c.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlTruncate, - Table: sqlgen.Table{c.tableN(0)}, +// Truncate deletes all rows from the table. +func (t *table) Truncate() error { + _, err := t.database.Exec(sqlgen.Statement{ + Type: sqlgen.Truncate, + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { return err } - return nil } -// Appends an item (map or struct) into the collection. -func (c *table) Append(item interface{}) (interface{}, error) { - +// Append inserts an item (map or struct) into the collection. +func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - var columns sqlgen.Columns - var values sqlgen.Values - var arguments []interface{} - cols, vals, err := c.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) - // Error ocurred, stop appending. if err != nil { return nil, err } - columns = make(sqlgen.Columns, 0, len(cols)) - for i := range cols { - columns = append(columns, sqlgen.Column{cols[i]}) - } + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) - arguments = make([]interface{}, 0, len(vals)) - values = make(sqlgen.Values, 0, len(vals)) - for i := range vals { - switch v := vals[i].(type) { - case sqlgen.Value: - // Adding value. - values = append(values, v) - default: - // Adding both value and placeholder. - values = append(values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = c.source.getPrimaryKey(c.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -260,14 +82,14 @@ func (c *table) Append(item interface{}) (interface{}, error) { } stmt := sqlgen.Statement{ - Type: sqlgen.SqlInsert, - Table: sqlgen.Table{c.tableN(0)}, - Columns: columns, - Values: values, + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = c.source.doExec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -294,10 +116,10 @@ func (c *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -320,14 +142,15 @@ func (c *table) Append(item interface{}) (interface{}, error) { return keyMap, nil } -// Returns true if the collection exists. -func (c *table) Exists() bool { - if err := c.source.tableExists(c.names...); err != nil { +// Exists returns true if the collection exists. +func (t *table) Exists() bool { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } -func (c *table) Name() string { - return strings.Join(c.names, `, `) +// Name returns the name of the table or tables that form the collection. +func (t *table) Name() string { + return strings.Join(t.Tables, `, `) } diff --git a/sqlite/database.go b/sqlite/database.go index d0998b154dd6ef506609a7c82fc057bb7433c94a..22de80bb25b0fbb068dfe720649e18137d903a95 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -24,35 +24,26 @@ package sqlite import ( "database/sql" "fmt" - "os" "strings" "time" - // Importing SQLite3 driver. "github.com/jmoiron/sqlx" - _ "github.com/mattn/go-sqlite3" - "upper.io/cache" + _ "github.com/mattn/go-sqlite3" // SQLite3 driver. "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" "upper.io/db/util/sqlutil" + "upper.io/db/util/sqlutil/tx" ) -const ( - // Adapter is the public name of the adapter. - Adapter = `sqlite` -) - -var template *sqlgen.Template - var ( - sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} + sqlPlaceholder = sqlgen.RawValue(`?`) ) -type source struct { +type database struct { connURL db.ConnectionURL session *sqlx.DB - tx *tx + tx *sqltx.Tx schema *schema.DatabaseSchema // columns property was introduced so we could query PRAGMA data only once // and retrieve all the column information we'd need, such as name and if it @@ -60,399 +51,392 @@ type source struct { columns map[string][]columnSchemaT } -type columnSchemaT struct { - Name string `db:"name"` - PK int `db:"pk"` +type tx struct { + *sqltx.Tx + *database } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} +var ( + _ = db.Database(&database{}) + _ = db.Tx(&tx{}) +) -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() - } +type columnSchemaT struct { + Name string `db:"name"` + PK int `db:"pk"` } -func init() { - - template = &sqlgen.Template{ - sqlColumnSeparator, - sqlIdentifierSeparator, - sqlIdentifierQuote, - sqlValueSeparator, - sqlValueQuote, - sqlAndKeyword, - sqlOrKeyword, - sqlNotKeyword, - sqlDescKeyword, - sqlAscKeyword, - sqlDefaultOperator, - sqlClauseGroup, - sqlClauseOperator, - sqlColumnValue, - sqlTableAliasLayout, - sqlColumnAliasLayout, - sqlSortByColumnLayout, - sqlWhereLayout, - sqlOrderByLayout, - sqlInsertLayout, - sqlSelectLayout, - sqlUpdateLayout, - sqlDeleteLayout, - sqlTruncateLayout, - sqlDropDatabaseLayout, - sqlDropTableLayout, - sqlSelectCountLayout, - sqlGroupByLayout, - cache.NewCache(), - } - - db.Register(Adapter, &source{}) +// Driver returns the underlying *sqlx.DB instance. +func (d *database) Driver() interface{} { + return d.session } -func (s *source) populateSchema() (err error) { - var collections []string +// Open attempts to connect to the database server using already stored settings. +func (d *database) Open() error { + var err error - s.schema = schema.NewDatabaseSchema() + // Before db.ConnectionURL we used a unified db.Settings struct. This + // condition checks for that type and provides backwards compatibility. + if settings, ok := d.connURL.(db.Settings); ok { + // User is providing a db.Settings struct, let's translate it into a + // ConnectionURL{}. + conn := ConnectionURL{ + Database: settings.Database, + Options: map[string]string{ + "cache": "shared", + }, + } - var conn ConnectionURL + d.connURL = conn + } - if conn, err = ParseURL(s.connURL.String()); err != nil { + if d.session, err = sqlx.Open(`sqlite3`, d.connURL.String()); err != nil { return err } - s.schema.Name = conn.Database + d.session.Mapper = sqlutil.NewMapper() - // The Collections() call will populate schema if its nil. - if collections, err = s.Collections(); err != nil { + if err = d.populateSchema(); err != nil { return err } - for i := range collections { - // Populate each collection. - if _, err = s.Collection(collections[i]); err != nil { - return err - } - } + return nil +} - return err +// Clone returns a cloned db.Database session, this is typically used for +// transactions. +func (d *database) Clone() (db.Database, error) { + return d.clone() } -func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { - var query string - var res sql.Result - var err error - var start, end int64 +func (d *database) clone() (*database, error) { + src := &database{} + src.Setup(d.connURL) - start = time.Now().UnixNano() + if err := src.Open(); err != nil { + return nil, err + } - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + return src, nil +} - if s.session == nil { - return nil, db.ErrNotConnected +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (d *database) Ping() error { + return d.session.Ping() +} + +// Close terminates the current database session. +func (d *database) Close() error { + if d.session != nil { + return d.session.Close() } + return nil +} - query = stmt.Compile(template) +// Collection returns a table by name. +func (d *database) Collection(names ...string) (db.Collection, error) { + var err error - if s.tx != nil { - res, err = s.tx.sqlTx.Exec(query, args...) - } else { - res, err = s.session.Exec(query, args...) + if len(names) == 0 { + return nil, db.ErrMissingCollectionName } - return res, err -} + if d.tx != nil { + if d.tx.Done() { + return nil, sql.ErrTxDone + } + } -func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { - var rows *sqlx.Rows - var query string - var err error - var start, end int64 + col := &table{database: d} + col.T.Tables = names + col.T.Mapper = d.session.Mapper - start = time.Now().UnixNano() + for _, name := range names { + chunks := strings.SplitN(name, ` `, 2) - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + if len(chunks) == 0 { + return nil, db.ErrMissingCollectionName + } - if s.session == nil { - return nil, db.ErrNotConnected - } + tableName := chunks[0] - query = stmt.Compile(template) + if err := d.tableExists(tableName); err != nil { + return nil, err + } - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) - } else { - rows, err = s.session.Queryx(query, args...) + if col.Columns, err = d.tableColumns(tableName); err != nil { + return nil, err + } } - return rows, err + return col, nil } -func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { - var query string - var row *sqlx.Row - var err error - var start, end int64 +// Collections returns a list of non-system tables from the database. +func (d *database) Collections() (collections []string, err error) { - start = time.Now().UnixNano() + tablesInSchema := len(d.schema.Tables) - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() - - if s.session == nil { - return nil, db.ErrNotConnected + // Id.schema already populated? + if tablesInSchema > 0 { + // Pulling table names from schema. + return d.schema.Tables, nil } - query = stmt.Compile(template) + // Schema is empty. - if s.tx != nil { - row = s.tx.sqlTx.QueryRowx(query, args...) - } else { - row = s.session.QueryRowx(query, args...) + // Querying table names. + stmt := sqlgen.Statement{ + Type: sqlgen.Select, + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`tbl_name`), + ), + Table: sqlgen.TableWithName(`sqlite_master`), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`type`), + Operator: `=`, + Value: sqlgen.NewValue(`table`), + }, + ), } - return row, err -} - -func (s *source) doRawQuery(query string, args ...interface{}) (*sqlx.Rows, error) { + // Executing statement. var rows *sqlx.Rows - var err error - var start, end int64 + if rows, err = d.Query(stmt); err != nil { + return nil, err + } - start = time.Now().UnixNano() + defer rows.Close() - defer func() { - end = time.Now().UnixNano() - debugLog(query, args, err, start, end) - }() + collections = []string{} - if s.session == nil { - return nil, db.ErrNotConnected - } + var name string - if s.tx != nil { - rows, err = s.tx.sqlTx.Queryx(query, args...) - } else { - rows, err = s.session.Queryx(query, args...) + for rows.Next() { + // Getting table name. + if err = rows.Scan(&name); err != nil { + return nil, err + } + + // Adding table entry to schema. + d.schema.AddTable(name) + + // Adding table to collections array. + collections = append(collections, name) } - return rows, err + return collections, nil } -// Returns the string name of the database. -func (s *source) Name() string { - return s.schema.Name -} +// Use changes the active database. +func (d *database) Use(database string) (err error) { + var conn ConnectionURL + + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err + } + + conn.Database = database + + d.connURL = conn -// Ping verifies a connection to the database is still alive, -// establishing a connection if necessary. -func (s *source) Ping() error { - return s.session.Ping() + return d.Open() } -func (s *source) clone() (*source, error) { - src := &source{} - src.Setup(s.connURL) +// Drop removes all tables from the current database. +func (d *database) Drop() error { - if err := src.Open(); err != nil { - return nil, err - } + _, err := d.Query(sqlgen.Statement{ + Type: sqlgen.DropDatabase, + Database: sqlgen.DatabaseWithName(d.schema.Name), + }) - return src, nil + return err +} + +// Setup stores database settings. +func (d *database) Setup(connURL db.ConnectionURL) error { + d.connURL = connURL + return d.Open() } -func (s *source) Clone() (db.Database, error) { - return s.clone() +// Name returns the name of the database. +func (d *database) Name() string { + return d.schema.Name } -func (s *source) Transaction() (db.Tx, error) { +// Transaction starts a transaction block and returns a db.Tx struct that can +// be used to issue transactional queries. +func (d *database) Transaction() (db.Tx, error) { var err error - var clone *source + var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = s.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = s.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } - tx := &tx{source: clone, sqlTx: sqlTx} - - clone.tx = tx - - return tx, nil -} - -// Stores database settings. -func (s *source) Setup(conn db.ConnectionURL) error { - s.connURL = conn - return s.Open() -} + clone.tx = sqltx.New(sqlTx) -// Returns the underlying *sqlx.DB instance. -func (s *source) Driver() interface{} { - return s.session + return tx{Tx: clone.tx, database: clone}, nil } -// Attempts to connect to a database using the stored settings. -func (s *source) Open() error { +// Exec compiles and executes a statement that does not return any rows. +func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { + var query string + var res sql.Result var err error + var start, end int64 - // Before db.ConnectionURL we used a unified db.Settings struct. This - // condition checks for that type and provides backwards compatibility. - if settings, ok := s.connURL.(db.Settings); ok { - // User is providing a db.Settings struct, let's translate it into a - // ConnectionURL{}. - conn := ConnectionURL{ - Database: settings.Database, - Options: map[string]string{ - "cache": "shared", - }, - } + start = time.Now().UnixNano() - s.connURL = conn - } + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - if s.session, err = sqlx.Open(`sqlite3`, s.connURL.String()); err != nil { - return err + if d.session == nil { + return nil, db.ErrNotConnected } - s.session.Mapper = sqlutil.NewMapper() + query = stmt.Compile(template.Template) - if err = s.populateSchema(); err != nil { - return err + if d.tx != nil { + res, err = d.tx.Exec(query, args...) + } else { + res, err = d.session.Exec(query, args...) } - return nil + return res, err } -// Closes the current database session. -func (s *source) Close() error { - if s.session != nil { - return s.session.Close() - } - return nil -} +// Query compiles and executes a statement that returns rows. +func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows + var query string + var err error + var start, end int64 -// Changes the active database. -func (s *source) Use(database string) (err error) { - var conn ConnectionURL + start = time.Now().UnixNano() - if conn, err = ParseURL(s.connURL.String()); err != nil { - return err + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() + + if d.session == nil { + return nil, db.ErrNotConnected } - conn.Database = database + query = stmt.Compile(template.Template) - s.connURL = conn + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) + } else { + rows, err = d.session.Queryx(query, args...) + } - return s.Open() + return rows, err } -// Drops the currently active database. -func (s *source) Drop() error { - return db.ErrUnsupported -} +// QueryRow compiles and executes a statement that returns at most one row. +func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) { + var query string + var row *sqlx.Row + var err error + var start, end int64 -// Collections() Returns a list of non-system tables/collections contained -// within the currently active database. -func (s *source) Collections() (collections []string, err error) { + start = time.Now().UnixNano() - tablesInSchema := len(s.schema.Tables) + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - // Is schema already populated? - if tablesInSchema > 0 { - // Pulling table names from schema. - return s.schema.Tables, nil + if d.session == nil { + return nil, db.ErrNotConnected } - // Schema is empty. + query = stmt.Compile(template.Template) - // Querying table names. - stmt := sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Columns: sqlgen.Columns{ - {`tbl_name`}, - }, - Table: sqlgen.Table{`sqlite_master`}, - Where: sqlgen.Where{ - sqlgen.ColumnValue{ - sqlgen.Column{`type`}, - `=`, - sqlgen.Value{`table`}, - }, - }, + if d.tx != nil { + row = d.tx.QueryRowx(query, args...) + } else { + row = d.session.QueryRowx(query, args...) } - // Executing statement. - var rows *sqlx.Rows - if rows, err = s.doQuery(stmt); err != nil { - return nil, err - } + return row, err +} - defer rows.Close() +// populateSchema looks up for the table info in the database and populates its +// schema for internal use. +func (d *database) populateSchema() (err error) { + var collections []string - collections = []string{} + d.schema = schema.NewDatabaseSchema() - var name string + var conn ConnectionURL - for rows.Next() { - // Getting table name. - if err = rows.Scan(&name); err != nil { - return nil, err - } + if conn, err = ParseURL(d.connURL.String()); err != nil { + return err + } - // Adding table entry to schema. - s.schema.AddTable(name) + d.schema.Name = conn.Database - // Adding table to collections array. - collections = append(collections, name) + // The Collections() call will populate schema if its nil. + if collections, err = d.Collections(); err != nil { + return err } - return collections, nil + for i := range collections { + // Populate each collection. + if _, err = d.Collection(collections[i]); err != nil { + return err + } + } + + return err } -func (s *source) tableExists(names ...string) error { +func (d *database) tableExists(names ...string) error { var stmt sqlgen.Statement var err error var rows *sqlx.Rows for i := range names { - if s.schema.HasTable(names[i]) { + if d.schema.HasTable(names[i]) { // We already know this table exists. continue } stmt = sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{`sqlite_master`}, - Columns: sqlgen.Columns{ - {`tbl_name`}, - }, - Where: sqlgen.Where{ - sqlgen.ColumnValue{sqlgen.Column{`type`}, `=`, sqlPlaceholder}, - sqlgen.ColumnValue{sqlgen.Column{`tbl_name`}, `=`, sqlPlaceholder}, - }, + Type: sqlgen.Select, + Table: sqlgen.TableWithName(`sqlite_master`), + Columns: sqlgen.JoinColumns( + sqlgen.ColumnWithName(`tbl_name`), + ), + Where: sqlgen.WhereConditions( + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`type`), + Operator: `=`, + Value: sqlPlaceholder, + }, + &sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(`tbl_name`), + Operator: `=`, + Value: sqlPlaceholder, + }, + ), } - if rows, err = s.doQuery(stmt, `table`, names[i]); err != nil { + if rows, err = d.Query(stmt, `table`, names[i]); err != nil { return db.ErrCollectionDoesNotExist } @@ -466,10 +450,10 @@ func (s *source) tableExists(names ...string) error { return nil } -func (s *source) tableColumns(tableName string) ([]string, error) { +func (d *database) tableColumns(tableName string) ([]string, error) { // Making sure this table is allocated. - tableSchema := s.schema.Table(tableName) + tableSchema := d.schema.Table(tableName) if len(tableSchema.Columns) > 0 { return tableSchema.Columns, nil @@ -477,10 +461,10 @@ func (s *source) tableColumns(tableName string) ([]string, error) { q := fmt.Sprintf(`PRAGMA TABLE_INFO('%s')`, tableName) - rows, err := s.doRawQuery(q) + rows, err := d.doRawQuery(q) - if s.columns == nil { - s.columns = make(map[string][]columnSchemaT) + if d.columns == nil { + d.columns = make(map[string][]columnSchemaT) } columns := []columnSchemaT{} @@ -489,81 +473,64 @@ func (s *source) tableColumns(tableName string) ([]string, error) { return nil, err } - s.columns[tableName] = columns + d.columns[tableName] = columns - s.schema.TableInfo[tableName].Columns = make([]string, 0, len(s.columns)) + d.schema.TableInfo[tableName].Columns = make([]string, 0, len(d.columns)) - for i := range s.columns[tableName] { - s.schema.TableInfo[tableName].Columns = append(s.schema.TableInfo[tableName].Columns, s.columns[tableName][i].Name) + for i := range d.columns[tableName] { + d.schema.TableInfo[tableName].Columns = append(d.schema.TableInfo[tableName].Columns, d.columns[tableName][i].Name) } - return s.schema.TableInfo[tableName].Columns, nil + return d.schema.TableInfo[tableName].Columns, nil } -// Returns a collection instance by name. -func (s *source) Collection(names ...string) (db.Collection, error) { - var err error - - if len(names) == 0 { - return nil, db.ErrMissingCollectionName - } - - if s.tx != nil { - if s.tx.done { - return nil, sql.ErrTxDone - } - } +func (d *database) getPrimaryKey(tableName string) ([]string, error) { + tableSchema := d.schema.Table(tableName) - col := &table{ - source: s, - names: names, - } + d.tableColumns(tableName) - for _, name := range names { - chunks := strings.SplitN(name, ` `, 2) + maxValue := -1 - if len(chunks) == 0 { - return nil, db.ErrMissingCollectionName + for i := range d.columns[tableName] { + if d.columns[tableName][i].PK > 0 && d.columns[tableName][i].PK > maxValue { + maxValue = d.columns[tableName][i].PK } + } - tableName := chunks[0] - - if err := s.tableExists(tableName); err != nil { - return nil, err - } + if maxValue > 0 { + tableSchema.PrimaryKey = make([]string, maxValue) - if col.Columns, err = s.tableColumns(tableName); err != nil { - return nil, err + for i := range d.columns[tableName] { + if d.columns[tableName][i].PK > 0 { + tableSchema.PrimaryKey[d.columns[tableName][i].PK-1] = d.columns[tableName][i].Name + } } } - return col, nil + return tableSchema.PrimaryKey, nil } -// getPrimaryKey returns the names of the columns that define the primary key -// of the table. -func (s *source) getPrimaryKey(tableName string) ([]string, error) { - tableSchema := s.schema.Table(tableName) +func (d *database) doRawQuery(query string, args ...interface{}) (*sqlx.Rows, error) { + var rows *sqlx.Rows + var err error + var start, end int64 - s.tableColumns(tableName) + start = time.Now().UnixNano() - maxValue := -1 + defer func() { + end = time.Now().UnixNano() + sqlutil.Log(query, args, err, start, end) + }() - for i := range s.columns[tableName] { - if s.columns[tableName][i].PK > 0 && s.columns[tableName][i].PK > maxValue { - maxValue = s.columns[tableName][i].PK - } + if d.session == nil { + return nil, db.ErrNotConnected } - if maxValue > 0 { - tableSchema.PrimaryKey = make([]string, maxValue) - - for i := range s.columns[tableName] { - if s.columns[tableName][i].PK > 0 { - tableSchema.PrimaryKey[s.columns[tableName][i].PK-1] = s.columns[tableName][i].Name - } - } + if d.tx != nil { + rows, err = d.tx.Queryx(query, args...) + } else { + rows, err = d.session.Queryx(query, args...) } - return tableSchema.PrimaryKey, nil + return rows, err } diff --git a/sqlite/database_test.go b/sqlite/database_test.go index 4480e392fb2145268def9cd0fdf15f5e245b573c..da54e350e4550ed95064f733db6e3c27a28ac2d6 100644 --- a/sqlite/database_test.go +++ b/sqlite/database_test.go @@ -47,7 +47,7 @@ import ( ) const ( - database = `_dumps/gotest.sqlite3.db` + databaseName = `_dumps/gotest.sqlite3.db` ) const ( @@ -55,7 +55,7 @@ const ( ) var settings = ConnectionURL{ - Database: database, + Database: databaseName, } // Structure for testing conversions and datatypes. @@ -167,7 +167,7 @@ func TestOldSettings(t *testing.T) { var sess db.Database oldSettings := db.Settings{ - Database: database, + Database: databaseName, } // Opening database. @@ -1378,10 +1378,24 @@ func TestDataTypes(t *testing.T) { loc, _ := time.LoadLocation(testTimeZone) item.Date = item.Date.In(loc) + // TODO: Try to guess this conversion. + if item.DateP.Location() != testValues.DateP.Location() { + v := item.DateP.In(testValues.DateP.Location()) + item.DateP = &v + } + // The original value and the test subject must match. - if reflect.DeepEqual(item, testValues) == false { - fmt.Printf("item1: %v\n", item) - fmt.Printf("test2: %v\n", testValues) + if !reflect.DeepEqual(item, testValues) { + fmt.Printf("item1: %#v\n", item) + fmt.Printf("test2: %#v\n", testValues) + fmt.Printf("item1: %#v\n", item.Date.String()) + fmt.Printf("test2: %#v\n", testValues.Date.String()) + fmt.Printf("item1: %v\n", item.Date.Location().String()) + fmt.Printf("test2: %v\n", testValues.Date.Location().String()) + fmt.Printf("item1: %#v\n", item.DateP) + fmt.Printf("test2: %#v\n", testValues.DateP) + fmt.Printf("item1: %v\n", item.DateP.Location().String()) + fmt.Printf("test2: %v\n", testValues.DateP.Location().String()) t.Fatalf("Struct is different.") } } @@ -1402,7 +1416,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec(`DELETE FROM "artist"`); err != nil { b.Fatal(err) @@ -1456,7 +1470,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/sqlite/result.go b/sqlite/result.go deleted file mode 100644 index 6504200c81894b9a1659678fc25db5b969c9b3f7..0000000000000000000000000000000000000000 --- a/sqlite/result.go +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package sqlite - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "upper.io/db" - "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil" -) - -type counter struct { - Total uint64 `db:"_t"` -} - -type result struct { - table *table - cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. - limit sqlgen.Limit - offset sqlgen.Offset - columns sqlgen.Columns - where sqlgen.Where - orderBy sqlgen.OrderBy - groupBy sqlgen.GroupBy - arguments []interface{} -} - -// Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { - var err error - // We need a cursor, if the cursor does not exists yet then we create one. - if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ - Type: sqlgen.SqlSelect, - Table: sqlgen.Table{r.table.Name()}, - Columns: r.columns, - Limit: r.limit, - Offset: r.offset, - Where: r.where, - OrderBy: r.orderBy, - GroupBy: r.groupBy, - }, r.arguments...) - } - return err -} - -// Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) - return r -} - -// Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { - r.limit = sqlgen.Limit(n) - return r -} - -// Determines how many documents will be skipped before starting to grab -// results. -func (r *result) Skip(n uint) db.Result { - r.offset = sqlgen.Offset(n) - return r -} - -// Used to group results that have the same value in the same column or -// columns. -func (r *result) Group(fields ...interface{}) db.Result { - - groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) - - l := len(fields) - - for i := 0; i < l; i++ { - switch value := fields[i].(type) { - // Maybe other types? - default: - groupByColumns = append(groupByColumns, sqlgen.Column{value}) - } - } - - r.groupBy = groupByColumns - - return r -} - -// Determines sorting of results according to the provided names. Fields may be -// prefixed by - (minus) which means descending order, ascending order would be -// used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { - - sortColumns := make(sqlgen.SortColumns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var sort sqlgen.SortColumn - - switch value := fields[i].(type) { - case db.Raw: - sort = sqlgen.SortColumn{ - sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}, - sqlgen.SqlSortAsc, - } - case string: - if strings.HasPrefix(value, `-`) { - // Explicit descending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value[1:]}, - sqlgen.SqlSortDesc, - } - } else { - // Ascending order. - sort = sqlgen.SortColumn{ - sqlgen.Column{value}, - sqlgen.SqlSortAsc, - } - } - } - sortColumns = append(sortColumns, sort) - } - - r.orderBy.SortColumns = sortColumns - - return r -} - -// Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { - - r.columns = make(sqlgen.Columns, 0, len(fields)) - - l := len(fields) - for i := 0; i < l; i++ { - var col sqlgen.Column - switch value := fields[i].(type) { - case db.Func: - v := interfaceArgs(value.Args) - var s string - if len(v) == 0 { - s = fmt.Sprintf(`%s()`, value.Name) - } else { - ss := make([]string, 0, len(v)) - for j := range v { - ss = append(ss, fmt.Sprintf(`%v`, v[j])) - } - s = fmt.Sprintf(`%s(%s)`, value.Name, strings.Join(ss, `, `)) - } - col = sqlgen.Column{sqlgen.Raw{s}} - case db.Raw: - col = sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}} - default: - col = sqlgen.Column{value} - } - r.columns = append(r.columns, col) - } - - return r -} - -// Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - // Current cursor. - err = r.setCursor() - - if err != nil { - return err - } - - defer r.Close() - - // Fetching all results within the cursor. - err = sqlutil.FetchRows(r.cursor, dst) - - return err -} - -// Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { - var err error - - if r.cursor != nil { - return db.ErrQueryIsPending - } - - defer r.Close() - - err = r.Next(dst) - - return err -} - -// Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { - - if err = r.setCursor(); err != nil { - r.Close() - return err - } - - if err = sqlutil.FetchRow(r.cursor, dst); err != nil { - r.Close() - return err - } - - return nil -} - -// Removes the matching items from the collection. -func (r *result) Remove() error { - var err error - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlDelete, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - return err - -} - -// Updates matching items from the collection with values of the given map or -// struct. -func (r *result) Update(values interface{}) error { - - ff, vv, err := r.table.FieldValues(values) - if err != nil { - return err - } - - total := len(ff) - - cvs := make(sqlgen.ColumnValues, 0, total) - - for i := 0; i < total; i++ { - cvs = append(cvs, sqlgen.ColumnValue{sqlgen.Column{ff[i]}, "=", sqlPlaceholder}) - } - - vv = append(vv, r.arguments...) - - _, err = r.table.source.doExec(sqlgen.Statement{ - Type: sqlgen.SqlUpdate, - Table: sqlgen.Table{r.table.Name()}, - ColumnValues: cvs, - Where: r.where, - }, vv...) - - return err -} - -// Closes the result set. -func (r *result) Close() (err error) { - if r.cursor != nil { - err = r.cursor.Close() - r.cursor = nil - } - return err -} - -// Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { - var count counter - - row, err := r.table.source.doQueryRow(sqlgen.Statement{ - Type: sqlgen.SqlSelectCount, - Table: sqlgen.Table{r.table.Name()}, - Where: r.where, - }, r.arguments...) - - if err != nil { - return 0, err - } - - err = row.Scan(&count.Total) - if err != nil { - return 0, err - } - - return count.Total, nil -} diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go new file mode 100644 index 0000000000000000000000000000000000000000..ea66cafdedeb1a49e361414cda975acdfb47c827 --- /dev/null +++ b/sqlite/sqlite.go @@ -0,0 +1,71 @@ +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sqlite + +import ( + "upper.io/cache" + "upper.io/db" + "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +// Adapter is the public name of the adapter. +const Adapter = `sqlite` + +var template *sqlutil.TemplateWithUtils + +func init() { + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + NotKeyword: adapterNotKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + }) + + db.Register(Adapter, &database{}) +} diff --git a/sqlite/layout.go b/sqlite/template.go similarity index 65% rename from sqlite/layout.go rename to sqlite/template.go index 9d937696e41b9247a817b61f8507774850b62424..abea16614305df659ce23b2b33a84df02a01da90 100644 --- a/sqlite/layout.go +++ b/sqlite/template.go @@ -22,37 +22,38 @@ package sqlite const ( - sqlColumnSeparator = `.` - sqlIdentifierSeparator = `, ` - sqlIdentifierQuote = `"{{.Raw}}"` - sqlValueSeparator = `, ` - sqlValueQuote = `'{{.}}'` - sqlAndKeyword = `AND` - sqlOrKeyword = `OR` - sqlNotKeyword = `NOT` - sqlDescKeyword = `DESC` - sqlAscKeyword = `ASC` - sqlDefaultOperator = `=` - sqlClauseGroup = `({{.}})` - sqlClauseOperator = ` {{.}} ` - sqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` - sqlTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - sqlColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` - sqlSortByColumnLayout = `{{.Column}} {{.Sort}}` - - sqlOrderByLayout = ` + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterNotKeyword = `NOT` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` {{if .SortColumns}} ORDER BY {{.SortColumns}} {{end}} ` - sqlWhereLayout = ` + adapterWhereLayout = ` {{if .Conds}} WHERE {{.Conds}} {{end}} ` - sqlSelectLayout = ` + adapterSelectLayout = ` SELECT {{if .Columns}} @@ -75,24 +76,24 @@ const ( {{if .Offset}} {{if not .Limit}} - LIMIT -1 + LIMIT -1 {{end}} OFFSET {{.Offset}} {{end}} ` - sqlDeleteLayout = ` + adapterDeleteLayout = ` DELETE FROM {{.Table}} {{.Where}} ` - sqlUpdateLayout = ` + adapterUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} {{ .Where }} ` - sqlSelectCountLayout = ` + adapterSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} @@ -104,13 +105,13 @@ const ( {{if .Offset}} {{if not .Limit}} - LIMIT -1 + LIMIT -1 {{end}} OFFSET {{.Offset}} {{end}} ` - sqlInsertLayout = ` + adapterInsertLayout = ` INSERT INTO {{.Table}} ({{.Columns}}) VALUES @@ -118,23 +119,21 @@ const ( {{.Extra}} ` - sqlTruncateLayout = ` + adapterTruncateLayout = ` DELETE FROM {{.Table}} ` - sqlDropDatabaseLayout = ` + adapterDropDatabaseLayout = ` DROP DATABASE {{.Database}} ` - sqlDropTableLayout = ` + adapterDropTableLayout = ` DROP TABLE {{.Table}} ` - sqlGroupByLayout = ` + adapterGroupByLayout = ` {{if .GroupColumns}} GROUP BY {{.GroupColumns}} {{end}} ` - - sqlNull = `NULL` ) diff --git a/sqlite/tx.go b/sqlite/tx.go deleted file mode 100644 index dbfc7698441d78012ba56494a4054149da254179..0000000000000000000000000000000000000000 --- a/sqlite/tx.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package sqlite - -import ( - "github.com/jmoiron/sqlx" -) - -type tx struct { - *source - sqlTx *sqlx.Tx - done bool -} - -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { - t.done = true - } - return err -} - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -} diff --git a/util/sqlgen/template.go b/util/sqlgen/template.go index 78663f54ba94edddf66c2f59b38b07d986c13e0f..bea487f39f26fb4061b9aeef3c81bd2973f7fba5 100644 --- a/util/sqlgen/template.go +++ b/util/sqlgen/template.go @@ -48,6 +48,7 @@ type Template struct { DescKeyword string AscKeyword string DefaultOperator string + AssignmentOperator string ClauseGroup string ClauseOperator string ColumnValue string diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go new file mode 100644 index 0000000000000000000000000000000000000000..d7bc224c8e75bf632986ed8ea87513644dc238e0 --- /dev/null +++ b/util/sqlutil/convert.go @@ -0,0 +1,201 @@ +package sqlutil + +import ( + "fmt" + "reflect" + "strings" + "upper.io/db" + "upper.io/db/util/sqlgen" +) + +var ( + sqlPlaceholder = sqlgen.RawValue(`?`) + sqlNull = sqlgen.RawValue(`NULL`) +) + +type TemplateWithUtils struct { + *sqlgen.Template +} + +func NewTemplateWithUtils(template *sqlgen.Template) *TemplateWithUtils { + return &TemplateWithUtils{template} +} + +// ToWhereWithArguments converts the given db.Cond parameters into a sqlgen.Where +// value. +func (tu *TemplateWithUtils) ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + for i := range t { + w, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case db.And: + var op sqlgen.And + for i := range t { + k, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + op.Conditions = append(op.Conditions, k.Conditions...) + } + where.Conditions = append(where.Conditions, &op) + return + case db.Or: + var op sqlgen.Or + for i := range t { + w, v := tu.ToWhereWithArguments(t[i]) + args = append(args, v...) + op.Conditions = append(op.Conditions, w.Conditions...) + } + where.Conditions = append(where.Conditions, &op) + return + case db.Raw: + if s, ok := t.Value.(string); ok { + where.Conditions = append(where.Conditions, sqlgen.RawValue(s)) + } + return + case db.Cond: + cv, v := tu.ToColumnValues(t) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + return + case db.Constrainer: + cv, v := tu.ToColumnValues(t.Constraint()) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + return + } + + panic(fmt.Sprintf(db.ErrUnknownConditionType.Error(), term)) +} + +// ToInterfaceArguments converts the given value into an array of interfaces. +func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []interface{}) { + if value == nil { + return nil + } + + v := reflect.ValueOf(value) + + switch v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = v.Index(i).Interface() + } + + return args + } + return nil + default: + args = []interface{}{value} + } + + return args +} + +// ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct. +func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { + + args = []interface{}{} + + for column, value := range cond { + columnValue := sqlgen.ColumnValue{} + + // Guessing operator from input, or using a default one. + column := strings.TrimSpace(column) + chunks := strings.SplitN(column, ` `, 2) + + columnValue.Column = sqlgen.ColumnWithName(chunks[0]) + + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } else { + columnValue.Operator = tu.DefaultOperator + } + + switch value := value.(type) { + case db.Func: + v := tu.ToInterfaceArguments(value.Args) + columnValue.Operator = value.Name + + if v == nil { + // A function with no arguments. + columnValue.Value = sqlgen.RawValue(`()`) + } else { + // A function with one or more arguments. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } + + args = append(args, v...) + default: + v := tu.ToInterfaceArguments(value) + + l := len(v) + if v == nil || l == 0 { + // Nil value given. + columnValue.Value = sqlNull + } else { + if l > 1 { + // Array value given. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } else { + // Single value given. + columnValue.Value = sqlPlaceholder + } + args = append(args, v...) + } + } + + ToColumnValues.ColumnValues = append(ToColumnValues.ColumnValues, &columnValue) + } + + return ToColumnValues, args +} + +// ToColumnsValuesAndArguments maps the given columnNames and columnValues into +// sqlgen's Columns and Values, it also extracts and returns query arguments. +func (tu *TemplateWithUtils) ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { + var arguments []interface{} + + columns := new(sqlgen.Columns) + + columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames)) + for i := range columnNames { + columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i])) + } + + values := new(sqlgen.Values) + + arguments = make([]interface{}, 0, len(columnValues)) + values.Values = make([]sqlgen.Fragment, 0, len(columnValues)) + + for i := range columnValues { + switch v := columnValues[i].(type) { + case *sqlgen.Value: + // Adding value. + values.Values = append(values.Values, v) + case sqlgen.Value: + // Adding value. + values.Values = append(values.Values, &v) + default: + // Adding both value and placeholder. + values.Values = append(values.Values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + + return columns, values, arguments, nil +} diff --git a/util/sqlutil/debug.go b/util/sqlutil/debug.go index 08d8ebe9f9f03bad1ec16b428edcdf515f80e423..f82a98577de0c1eef0661aedbda9050c62e31413 100644 --- a/util/sqlutil/debug.go +++ b/util/sqlutil/debug.go @@ -24,7 +24,10 @@ package sqlutil import ( "fmt" "log" + "os" "strings" + + "upper.io/db" ) // Debug is used for printing SQL queries and arguments. @@ -59,3 +62,17 @@ func (d *Debug) Print() { log.Printf("\n\t%s\n\n", strings.Join(s, "\n\t")) } + +func IsDebugEnabled() bool { + if os.Getenv(db.EnvEnableDebug) != "" { + return true + } + return false +} + +func Log(query string, args []interface{}, err error, start int64, end int64) { + if IsDebugEnabled() { + d := Debug{query, args, err, start, end} + d.Print() + } +} diff --git a/postgresql/result.go b/util/sqlutil/result/result.go similarity index 79% rename from postgresql/result.go rename to util/sqlutil/result/result.go index b1d145b09dde3f6ce05a968f44f90b059f1100fe..7badb0e750e2774ff0f511bb1018cccd8b624813 100644 --- a/postgresql/result.go +++ b/util/sqlutil/result/result.go @@ -19,7 +19,7 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package postgresql +package result import ( "fmt" @@ -31,12 +31,16 @@ import ( "upper.io/db/util/sqlutil" ) +var ( + sqlPlaceholder = sqlgen.RawValue(`?`) +) + type counter struct { Total uint64 `db:"_t"` } -type result struct { - table *table +type Result struct { + table DataProvider cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. limit sqlgen.Limit offset sqlgen.Offset @@ -45,14 +49,26 @@ type result struct { orderBy sqlgen.OrderBy groupBy sqlgen.GroupBy arguments []interface{} + template *sqlutil.TemplateWithUtils +} + +// NewResult creates and results a new result set on the given table, this set +// is limited by the given sqlgen.Where conditions. +func NewResult(template *sqlutil.TemplateWithUtils, p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { + return &Result{ + table: p, + where: where, + arguments: arguments, + template: template, + } } // Executes a SELECT statement that can feed Next(), All() or One(). -func (r *result) setCursor() error { +func (r *Result) setCursor() error { var err error // We need a cursor, if the cursor does not exists yet then we create one. if r.cursor == nil { - r.cursor, err = r.table.source.doQuery(sqlgen.Statement{ + r.cursor, err = r.table.Query(sqlgen.Statement{ Type: sqlgen.Select, Table: sqlgen.TableWithName(r.table.Name()), Columns: &r.columns, @@ -67,27 +83,27 @@ func (r *result) setCursor() error { } // Sets conditions for reducing the working set. -func (r *result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = whereValues(terms) +func (r *Result) Where(terms ...interface{}) db.Result { + r.where, r.arguments = r.template.ToWhereWithArguments(terms) return r } // Determines the maximum limit of results to be returned. -func (r *result) Limit(n uint) db.Result { +func (r *Result) Limit(n uint) db.Result { r.limit = sqlgen.Limit(n) return r } // Determines how many documents will be skipped before starting to grab // results. -func (r *result) Skip(n uint) db.Result { +func (r *Result) Skip(n uint) db.Result { r.offset = sqlgen.Offset(n) return r } // Used to group results that have the same value in the same column or // columns. -func (r *result) Group(fields ...interface{}) db.Result { +func (r *Result) Group(fields ...interface{}) db.Result { var columns []sqlgen.Fragment for i := range fields { @@ -107,7 +123,7 @@ func (r *result) Group(fields ...interface{}) db.Result { // Determines sorting of results according to the provided names. Fields may be // prefixed by - (minus) which means descending order, ascending order would be // used otherwise. -func (r *result) Sort(fields ...interface{}) db.Result { +func (r *Result) Sort(fields ...interface{}) db.Result { var sortColumns sqlgen.SortColumns @@ -144,7 +160,7 @@ func (r *result) Sort(fields ...interface{}) db.Result { } // Retrieves only the given fields. -func (r *result) Select(fields ...interface{}) db.Result { +func (r *Result) Select(fields ...interface{}) db.Result { r.columns = sqlgen.Columns{} @@ -152,7 +168,7 @@ func (r *result) Select(fields ...interface{}) db.Result { var col sqlgen.Fragment switch value := fields[i].(type) { case db.Func: - v := interfaceArgs(value.Args) + v := r.template.ToInterfaceArguments(value.Args) var s string if len(v) == 0 { s = fmt.Sprintf(`%s()`, value.Name) @@ -176,7 +192,7 @@ func (r *result) Select(fields ...interface{}) db.Result { } // Dumps all results into a pointer to an slice of structs or maps. -func (r *result) All(dst interface{}) error { +func (r *Result) All(dst interface{}) error { var err error if r.cursor != nil { @@ -199,7 +215,7 @@ func (r *result) All(dst interface{}) error { } // Fetches only one result from the resultset. -func (r *result) One(dst interface{}) error { +func (r *Result) One(dst interface{}) error { var err error if r.cursor != nil { @@ -214,7 +230,7 @@ func (r *result) One(dst interface{}) error { } // Fetches the next result from the resultset. -func (r *result) Next(dst interface{}) (err error) { +func (r *Result) Next(dst interface{}) (err error) { if err = r.setCursor(); err != nil { r.Close() @@ -230,10 +246,10 @@ func (r *result) Next(dst interface{}) (err error) { } // Removes the matching items from the collection. -func (r *result) Remove() error { +func (r *Result) Remove() error { var err error - _, err = r.table.source.doExec(sqlgen.Statement{ + _, err = r.table.Exec(sqlgen.Statement{ Type: sqlgen.Delete, Table: sqlgen.TableWithName(r.table.Name()), Where: &r.where, @@ -245,7 +261,7 @@ func (r *result) Remove() error { // Updates matching items from the collection with values of the given map or // struct. -func (r *result) Update(values interface{}) error { +func (r *Result) Update(values interface{}) error { ff, vv, err := r.table.FieldValues(values) if err != nil { @@ -255,12 +271,12 @@ func (r *result) Update(values interface{}) error { cvs := new(sqlgen.ColumnValues) for i := range ff { - cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: "=", Value: sqlPlaceholder}) + cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: r.template.AssignmentOperator, Value: sqlPlaceholder}) } vv = append(vv, r.arguments...) - _, err = r.table.source.doExec(sqlgen.Statement{ + _, err = r.table.Exec(sqlgen.Statement{ Type: sqlgen.Update, Table: sqlgen.TableWithName(r.table.Name()), ColumnValues: cvs, @@ -271,7 +287,7 @@ func (r *result) Update(values interface{}) error { } // Closes the result set. -func (r *result) Close() (err error) { +func (r *Result) Close() (err error) { if r.cursor != nil { err = r.cursor.Close() r.cursor = nil @@ -280,10 +296,10 @@ func (r *result) Close() (err error) { } // Counts the elements within the main conditions of the set. -func (r *result) Count() (uint64, error) { +func (r *Result) Count() (uint64, error) { var count counter - row, err := r.table.source.doQueryRow(sqlgen.Statement{ + row, err := r.table.QueryRow(sqlgen.Statement{ Type: sqlgen.Count, Table: sqlgen.TableWithName(r.table.Name()), Where: &r.where, diff --git a/util/sqlutil/result/table.go b/util/sqlutil/result/table.go new file mode 100644 index 0000000000000000000000000000000000000000..57ec42ac1bdf8abe90ea0667e7b8177089820c8d --- /dev/null +++ b/util/sqlutil/result/table.go @@ -0,0 +1,15 @@ +package result + +import ( + "database/sql" + "github.com/jmoiron/sqlx" + "upper.io/db/util/sqlgen" +) + +type DataProvider interface { + Name() string + Query(sqlgen.Statement, ...interface{}) (*sqlx.Rows, error) + QueryRow(sqlgen.Statement, ...interface{}) (*sqlx.Row, error) + Exec(sqlgen.Statement, ...interface{}) (sql.Result, error) + FieldValues(interface{}) ([]string, []interface{}, error) +} diff --git a/util/sqlutil/sqlutil.go b/util/sqlutil/sqlutil.go index 93427815168f626d965b3990c9b71fee3b0d21fc..4d2b54df5b3074dafa2e02561e149061c340be1e 100644 --- a/util/sqlutil/sqlutil.go +++ b/util/sqlutil/sqlutil.go @@ -49,6 +49,7 @@ var ( type T struct { Columns []string Mapper *reflectx.Mapper + Tables []string // Holds table names. } func (t *T) columnLike(s string) string { @@ -172,3 +173,19 @@ func normalizeColumn(s string) string { func NewMapper() *reflectx.Mapper { return reflectx.NewMapper("db") } + +// MainTableName returns the name of the first table. +func (t *T) MainTableName() string { + return t.NthTableName(0) +} + +// NthTableName returns the table name at index i. +func (t *T) NthTableName(i int) string { + if len(t.Tables) > i { + chunks := strings.SplitN(t.Tables[i], " ", 2) + if len(chunks) > 0 { + return chunks[0] + } + } + return "" +} diff --git a/postgresql/tx.go b/util/sqlutil/tx/tx.go similarity index 84% rename from postgresql/tx.go rename to util/sqlutil/tx/tx.go index 70b979d74e3763a57941e60e13e261083236778a..533c54058f9f855dbc3004e8fad03f5f3022e0bb 100644 --- a/postgresql/tx.go +++ b/util/sqlutil/tx/tx.go @@ -19,26 +19,28 @@ // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -package postgresql +package sqltx import ( "github.com/jmoiron/sqlx" ) -type tx struct { - *source - sqlTx *sqlx.Tx - done bool +type Tx struct { + *sqlx.Tx + done bool } -func (t *tx) Commit() (err error) { - err = t.sqlTx.Commit() - if err == nil { +func New(tx *sqlx.Tx) *Tx { + return &Tx{Tx: tx} +} + +func (t *Tx) Done() bool { + return t.done +} + +func (t *Tx) Commit() (err error) { + if err = t.Tx.Commit(); err == nil { t.done = true } return err } - -func (t *tx) Rollback() error { - return t.sqlTx.Rollback() -}