diff --git a/config.go b/config.go index f0e34e8249e83f3c79ea5a6dc0f7f5e8e8f5e27c..2c21554ea31d6f4de22daa4dad0ecdf6740b7b2e 100644 --- a/config.go +++ b/config.go @@ -37,10 +37,18 @@ type Settings interface { SetLogger(Logger) // Returns the currently configured logger. Logger() Logger + + // SetPreparedStatementCache enables or disables the prepared statement + // cache. + SetPreparedStatementCache(bool) + // PreparedStatementCacheEnabled returns true if the prepared statement cache + // is enabled, false otherwise. + PreparedStatementCacheEnabled() bool } type conf struct { - loggingEnabled uint32 + loggingEnabled uint32 + preparedStatementCacheEnabled uint32 queryLogger Logger queryLoggerMu sync.RWMutex @@ -65,20 +73,38 @@ func (c *conf) SetLogger(lg Logger) { c.queryLogger = lg } -func (c *conf) SetLogging(value bool) { +func (c *conf) binaryOption(opt *uint32) bool { + if atomic.LoadUint32(opt) == 1 { + return true + } + return false +} + +func (c *conf) setBinaryOption(opt *uint32, value bool) { if value { - atomic.StoreUint32(&c.loggingEnabled, 1) + atomic.StoreUint32(opt, 1) return } - atomic.StoreUint32(&c.loggingEnabled, 0) + atomic.StoreUint32(opt, 0) +} + +func (c *conf) SetLogging(value bool) { + c.setBinaryOption(&c.loggingEnabled, value) } func (c *conf) LoggingEnabled() bool { - if v := atomic.LoadUint32(&c.loggingEnabled); v == 1 { - return true - } - return false + return c.binaryOption(&c.loggingEnabled) +} + +func (c *conf) SetPreparedStatementCache(value bool) { + c.setBinaryOption(&c.preparedStatementCacheEnabled, value) +} + +func (c *conf) PreparedStatementCacheEnabled() bool { + return c.binaryOption(&c.preparedStatementCacheEnabled) } // Conf provides global configuration settings for upper-db. -var Conf Settings = &conf{} +var Conf Settings = &conf{ + preparedStatementCacheEnabled: 0, +} diff --git a/db.go b/db.go index 84c39fc967029a01de81880cf23dd5f9420b6f4a..bd6650509809555750705ac0d2014604d95e231c 100644 --- a/db.go +++ b/db.go @@ -489,6 +489,18 @@ type Database interface { // ClearCache clears all the cache mechanisms the adapter is using. ClearCache() + + // SetConnMaxLifetime sets the maximum amount of time a connection may be + // reused. + SetConnMaxLifetime(time.Duration) + + // SetMaxIdleConns sets the maximum number of connections in the idle + // connection pool. + SetMaxIdleConns(int) + + // SetMaxOpenConns sets the maximum number of open connections to the + // database. + SetMaxOpenConns(int) } // Tx has methods for transactions that can be either committed or rolled back. diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index 7c1fa4d0c33fb8c2bf1ed909d07562ab333f95bf..6c63a2f50dcaaa798af2a8c6efdb25231cc5b01e 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -27,7 +27,7 @@ type HasCleanUp interface { // HasStatementExec allows the adapter to have its own exec statement. type HasStatementExec interface { - StatementExec(stmt *sql.Stmt, args ...interface{}) (sql.Result, error) + StatementExec(query string, args ...interface{}) (sql.Result, error) } // Database represents a SQL database. @@ -49,7 +49,7 @@ type PartialDatabase interface { FindTablePrimaryKeys(name string) ([]string, error) NewLocalCollection(name string) db.Collection - CompileStatement(stmt *exql.Statement) (query string) + CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) ConnectionURL() db.ConnectionURL Err(in error) (out error) @@ -73,6 +73,12 @@ type BaseDatabase interface { BindTx(*sql.Tx) error Transaction() BaseTx + + SetConnMaxLifetime(time.Duration) + SetMaxIdleConns(int) + SetMaxOpenConns(int) + + BindClone(PartialDatabase) (BaseDatabase, error) } // NewBaseDatabase provides a BaseDatabase given a PartialDatabase @@ -98,6 +104,8 @@ type database struct { sess *sql.DB sessMu sync.Mutex + psMu sync.Mutex + sessID uint64 txID uint64 @@ -178,6 +186,30 @@ func (d *database) Ping() error { return nil } +// SetConnMaxLifetime sets the maximum amount of time a connection may be +// reused. +func (d *database) SetConnMaxLifetime(t time.Duration) { + if sess := d.Session(); sess != nil { + sess.SetConnMaxLifetime(t) + } +} + +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +func (d *database) SetMaxIdleConns(n int) { + if sess := d.Session(); sess != nil { + sess.SetMaxIdleConns(n) + } +} + +// SetMaxOpenConns sets the maximum number of open connections to the +// database. +func (d *database) SetMaxOpenConns(n int) { + if sess := d.Session(); sess != nil { + sess.SetMaxOpenConns(n) + } +} + // ClearCache removes all caches. func (d *database) ClearCache() { d.collectionMu.Lock() @@ -189,6 +221,20 @@ func (d *database) ClearCache() { } } +// BindClone binds a clone that is linked to the current +// session. This is commonly done before creating a transaction +// session. +func (d *database) BindClone(p PartialDatabase) (BaseDatabase, error) { + nd := NewBaseDatabase(p).(*database) + nd.name = d.name + nd.sess = d.sess + if err := nd.Ping(); err != nil { + return nil, err + } + nd.sessID = newSessionID() + return nd, nil +} + // Close terminates the current database session func (d *database) Close() error { defer func() { @@ -201,6 +247,7 @@ func (d *database) Close() error { if cleaner, ok := d.PartialDatabase.(HasCleanUp); ok { cleaner.CleanUp() } + d.cachedCollections.Clear() d.cachedStatements.Clear() // Closes prepared statements as well. @@ -212,6 +259,7 @@ func (d *database) Close() error { if !tx.Committed() { tx.Rollback() + return nil } } return nil @@ -267,18 +315,32 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res }(time.Now()) } - var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { - return nil, err + if execer, ok := d.PartialDatabase.(HasStatementExec); ok { + query, args = d.compileStatement(stmt, args) + res, err = execer.StatementExec(query, args...) + return } - defer p.Close() - if execer, ok := d.PartialDatabase.(HasStatementExec); ok { - res, err = execer.StatementExec(p.Stmt, args...) + tx := d.Transaction() + + if db.Conf.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = d.prepareStatement(stmt, args); err != nil { + return nil, err + } + defer p.Close() + + res, err = p.Exec(args...) + return + } + + query, args = d.compileStatement(stmt, args) + if tx != nil { + res, err = tx.(*sqlTx).Exec(query, args...) return } - res, err = p.Exec(args...) + res, err = d.sess.Exec(query, args...) return } @@ -300,14 +362,28 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro }(time.Now()) } - var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { - return nil, err + tx := d.Transaction() + + if db.Conf.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = d.prepareStatement(stmt, args); err != nil { + return nil, err + } + defer p.Close() + + rows, err = p.Query(args...) + return } - defer p.Close() - rows, err = p.Query(args...) + query, args = d.compileStatement(stmt, args) + if tx != nil { + rows, err = tx.(*sqlTx).Query(query, args...) + return + } + + rows, err = d.sess.Query(query, args...) return + } // StatementQueryRow compiles and executes a statement that returns at most one @@ -329,13 +405,26 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) }(time.Now()) } - var p *Stmt - if p, query, err = d.prepareStatement(stmt); err != nil { - return nil, err + tx := d.Transaction() + + if db.Conf.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = d.prepareStatement(stmt, args); err != nil { + return nil, err + } + defer p.Close() + + row = p.QueryRow(args...) + return + } + + query, args = d.compileStatement(stmt, args) + if tx != nil { + row = tx.(*sqlTx).QueryRow(query, args...) + return } - defer p.Close() - row, err = p.QueryRow(args...), nil + row = d.sess.QueryRow(query, args...) return } @@ -348,12 +437,20 @@ func (d *database) Driver() interface{} { return d.sess } -// prepareStatement converts a *exql.Statement representation into an actual -// *sql.Stmt. This method will attempt to used a cached prepared statement, if -// available. -func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) { - if d.sess == nil && d.Transaction() == nil { - return nil, "", db.ErrNotConnected +// compileStatement compiles the given statement into a string. +func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + return d.PartialDatabase.CompileStatement(stmt, args) +} + +// prepareStatement compiles a query and tries to use previously generated +// statement. +func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { + d.sessMu.Lock() + defer d.sessMu.Unlock() + + sess, tx := d.sess, d.Transaction() + if sess == nil && tx == nil { + return nil, "", nil, db.ErrNotConnected } pc, ok := d.cachedStatements.ReadRaw(stmt) @@ -361,26 +458,28 @@ func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) // The statement was cached. ps, err := pc.(*Stmt).Open() if err == nil { - return ps, ps.query, nil + _, args = d.compileStatement(stmt, args) + return ps, ps.query, args, nil } } - // Plain SQL query. - query := d.PartialDatabase.CompileStatement(stmt) - - sqlStmt, err := func() (*sql.Stmt, error) { - if d.Transaction() != nil { - return d.Transaction().(*sqlTx).Prepare(query) + query, args := d.compileStatement(stmt, args) + sqlStmt, err := func(query *string) (*sql.Stmt, error) { + if tx != nil { + return tx.(*sqlTx).Prepare(*query) } - return d.sess.Prepare(query) - }() + return sess.Prepare(*query) + }(&query) if err != nil { - return nil, query, err + return nil, "", nil, err } - p := NewStatement(sqlStmt, query) + p, err := NewStatement(sqlStmt, query).Open() + if err != nil { + return nil, query, args, err + } d.cachedStatements.Write(stmt, p) - return p, query, nil + return p, p.query, args, nil } var waitForConnMu sync.Mutex diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go index a57a3aede82d7225b2bd605e14a03963597feeea..17e7c6d70416bd9518ba31888fdd6af15a721e83 100644 --- a/internal/sqladapter/statement.go +++ b/internal/sqladapter/statement.go @@ -3,6 +3,7 @@ package sqladapter import ( "database/sql" "errors" + "sync" "sync/atomic" ) @@ -10,21 +11,16 @@ var ( activeStatements int64 ) -// NumActiveStatements returns the number of prepared statements in use at any -// point. -func NumActiveStatements() int64 { - return atomic.LoadInt64(&activeStatements) -} - // Stmt represents a *sql.Stmt that is cached and provides the // OnPurge method to allow it to clean after itself. type Stmt struct { *sql.Stmt query string + mu sync.Mutex count int64 - dead int32 + dead bool } // NewStatement creates an returns an opened statement @@ -32,43 +28,58 @@ func NewStatement(stmt *sql.Stmt, query string) *Stmt { s := &Stmt{ Stmt: stmt, query: query, - count: 1, } - // Increment active statements counter. atomic.AddInt64(&activeStatements, 1) return s } // Open marks the statement as in-use func (c *Stmt) Open() (*Stmt, error) { - if atomic.LoadInt32(&c.dead) > 0 { + c.mu.Lock() + defer c.mu.Unlock() + + if c.dead { return nil, errors.New("statement is dead") } - atomic.AddInt64(&c.count, 1) + + c.count++ return c, nil } // Close closes the underlying statement if no other go-routine is using it. -func (c *Stmt) Close() { - if atomic.AddInt64(&c.count, -1) > 0 { - // If this counter is more than 0 then there are other goroutines using - // this statement so we don't want to close it for real. - return - } +func (c *Stmt) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + c.count-- - if atomic.LoadInt32(&c.dead) > 0 && atomic.LoadInt64(&c.count) <= 0 { + return c.checkClose() +} + +func (c *Stmt) checkClose() error { + if c.dead && c.count == 0 { // Statement is dead and we can close it for real. - c.Stmt.Close() + err := c.Stmt.Close() + if err != nil { + return err + } // Reduce active statements counter. atomic.AddInt64(&activeStatements, -1) } + return nil } // OnPurge marks the statement as ready to be cleaned up. func (c *Stmt) OnPurge() { - // Mark as dead, you can continue using it but it will be closed for real - // when c.count reaches 0. - atomic.StoreInt32(&c.dead, 1) - // Call Close again to make sure we're closing the statement. - c.Close() + c.mu.Lock() + defer c.mu.Unlock() + + c.dead = true + c.checkClose() +} + +// NumActiveStatements returns the global number of prepared statements in use +// at any point. +func NumActiveStatements() int64 { + return atomic.LoadInt64(&activeStatements) } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 41150ddfa5b81f94fdaaf83a60d5dc4f7b22d2d9..34671fba8f47a5393683833bbd776f86722033fc 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -17,7 +17,6 @@ import ( "github.com/stretchr/testify/assert" "upper.io/db.v2" - "upper.io/db.v2/internal/sqladapter" "upper.io/db.v2/lib/sqlbuilder" ) @@ -79,6 +78,9 @@ func TestOpenMustSucceed(t *testing.T) { func TestPreparedStatementsCache(t *testing.T) { sess := mustOpen() + db.Conf.SetPreparedStatementCache(true) + defer db.Conf.SetPreparedStatementCache(false) + var tMu sync.Mutex tFatal := func(err error) { tMu.Lock() @@ -86,70 +88,74 @@ func TestPreparedStatementsCache(t *testing.T) { t.Fatal(err) } - // QL and SQLite don't have the same concurrency capabilities PostgreSQL and - // MySQL have, so they have special limits. - defaultLimit := 1000 - - limits := map[string]int { - "sqlite": 20, - "ql": 20, - } - - limit := limits[Adapter] - if limit < 1 { - limit = defaultLimit - } - - // The max number of elements we can have on our LRU is 128, if an statement - // is evicted it will be marked as dead and will be closed only when no other - // queries are using it. - const maxPreparedStatements = 128 * 2 - + // This limit was chosen because, by default, MySQL accepts 16k statements + // and dies. See https://github.com/upper/db/issues/287 + limit := 20000 var wg sync.WaitGroup + for i := 0; i < limit; i++ { wg.Add(1) go func(i int) { defer wg.Done() // This query is different with each iteration and thus generates a new // prepared statement everytime it's called. - res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("count(%d)", i%200))) + res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("count(%d)", i))) var count map[string]uint64 err := res.One(&count) if err != nil { tFatal(err) } - if activeStatements := sqladapter.NumActiveStatements(); activeStatements > maxPreparedStatements { - tFatal(fmt.Errorf("The number of active statements cannot exceed %d (got %d).", maxPreparedStatements, activeStatements)) - } }(i) - if i%50 == 0 { - wg.Wait() - } } wg.Wait() + // Concurrent Insert can open many connections on MySQL / PostgreSQL, this + // sets a limit on them. + sess.SetMaxOpenConns(100) + + switch Adapter { + case "ql": + limit = 1000 + case "sqlite": + // TODO: We'll probably be able to workaround this with a mutex on inserts. + t.Skip(`Skipped due to a "database is locked" problem with concurrent transactions. See https://github.com/mattn/go-sqlite3/issues/274`) + } + for i := 0; i < limit; i++ { wg.Add(1) go func(i int) { defer wg.Done() - // This query is different with each iteration and thus generates a new - // prepared statement everytime it's called. + // The same prepared query on every iteration. _, err := sess.Collection("artist").Insert(artistType{ - Name: fmt.Sprintf("artist-%d", i%200), - }) + Name: fmt.Sprintf("artist-%d", i), + }) if err != nil { tFatal(err) } - if activeStatements := sqladapter.NumActiveStatements(); activeStatements > maxPreparedStatements { - tFatal(fmt.Errorf("The number of active statements cannot exceed %d (got %d).", maxPreparedStatements, activeStatements)) + }(i) + } + wg.Wait() + + // Insert returning creates a transaction. + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // The same prepared query on every iteration. + artist := artistType{ + Name: fmt.Sprintf("artist-%d", i), + } + err := sess.Collection("artist").InsertReturning(&artist) + if err != nil { + tFatal(err) } }(i) - if i%50 == 0 { - wg.Wait() - } } wg.Wait() + // Removing the limit. + sess.SetMaxOpenConns(0) + assert.NoError(t, cleanUpCheck(sess)) assert.NoError(t, sess.Close()) } @@ -532,7 +538,7 @@ func TestGetResultsOneByOne(t *testing.T) { assert.Equal(t, 4, len(allRowsMap)) for _, singleRowMap := range allRowsMap { - if fmt.Sprintf("%d", singleRowMap["id"]) == "0" { + if fmt.Sprintf("%d", singleRowMap["id"]) == "0" { t.Fatalf("Expecting a not null ID.") } } @@ -1020,6 +1026,7 @@ func TestCompositeKeys(t *testing.T) { // Attempts to test database transactions. func TestTransactionsAndRollback(t *testing.T) { + if Adapter == "ql" { t.Skip("Currently not supported.") } @@ -1048,8 +1055,12 @@ func TestTransactionsAndRollback(t *testing.T) { err = tx.Close() assert.NoError(t, err) + err = tx.Close() + assert.NoError(t, err) + // Use another transaction. tx, err = sess.NewTx() + assert.NoError(t, err) artist = tx.Collection("artist") diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 2f953fc53e243285683d9a2f232ffa998ca038cb..1867b221c0d5e582864bf56d20bd4246f0499be5 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -26,6 +26,11 @@ var defaultMapOptions = MapOptions{ IncludeNil: false, } +type compilable interface { + Compile() string + Arguments() []interface{} +} + type hasStringer interface { Stringer() *stringer } @@ -117,6 +122,8 @@ func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, e return b.sess.StatementExec(q, args...) case string: return b.sess.StatementExec(exql.RawSQL(q), args...) + case db.RawValue: + return b.Exec(q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -128,6 +135,8 @@ func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, e return b.sess.StatementQuery(q, args...) case string: return b.sess.StatementQuery(exql.RawSQL(q), args...) + case db.RawValue: + return b.Query(q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -139,6 +148,8 @@ func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, return b.sess.StatementQueryRow(q, args...) case string: return b.sess.StatementQueryRow(exql.RawSQL(q), args...) + case db.RawValue: + return b.QueryRow(q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -320,7 +331,7 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err for i := 0; i < l; i++ { switch v := columns[i].(type) { case *selector: - expanded, rawArgs := expandPlaceholders(v.Compile(), v.Arguments()...) + expanded, rawArgs := expandPlaceholders(v.Compile(), v.Arguments()) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case db.Function: @@ -330,11 +341,11 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) f[i] = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: - expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()...) + expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case exql.Fragment: diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index f4345aeaca21fcca92c4dcdfb9c345bff80e6649..bc5d8f5fa8b649e5cdc5390bed6f613378e577b4 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "database/sql/driver" "fmt" "reflect" "strings" @@ -24,46 +25,28 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils { return &templateWithUtils{template} } -func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) { +func expandQuery(in string, args []interface{}, fn func(interface{}) (string, []interface{})) (string, []interface{}) { argn := 0 argx := make([]interface{}, 0, len(args)) for i := 0; i < len(in); i++ { - if in[i] == '?' { - if len(args) > argn { - k := `?` - - values, isSlice := toInterfaceArguments(args[argn]) - if isSlice { - if len(values) == 0 { - k = `(NULL)` - } else { - k = `(?` + strings.Repeat(`, ?`, len(values)-1) + `)` - } - } else { - if len(values) == 1 { - switch t := values[0].(type) { - case db.RawValue: - k, values = t.Raw(), nil - case *selector: - k, values = `(`+t.Compile()+`)`, t.Arguments() - } - } else if len(values) == 0 { - k = `NULL` - } - } - - if k != `?` { - in = in[:i] + k + in[i+1:] - i += len(k) - 1 - } - - if len(values) > 0 { - argx = append(argx, values...) - } - argn++ + if in[i] != '?' { + continue + } + if len(args) > argn { + k, values := fn(args[argn]) + if k != "" { + in = in[:i] + k + in[i+1:] + i += len(k) - 1 + } + if len(values) > 0 { + argx = append(argx, values...) } + argn++ } } + if len(argx) < len(args) { + argx = append(argx, args[argn:]...) + } return in, argx } @@ -97,6 +80,11 @@ func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) return nil, false } + switch t := value.(type) { + case driver.Valuer: + return []interface{}{t}, false + } + if v.Type().Kind() == reflect.Slice { var i, total int @@ -151,6 +139,39 @@ func toColumnsValuesAndArguments(columnNames []string, columnValues []interface{ return columns, values, arguments, nil } +func preprocessFn(arg interface{}) (string, []interface{}) { + values, isSlice := toInterfaceArguments(arg) + + if isSlice { + if len(values) == 0 { + return `(NULL)`, nil + } + return `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`, values + } + + if len(values) == 1 { + switch t := arg.(type) { + case db.RawValue: + return Preprocess(t.Raw(), t.Arguments()) + case compilable: + return `(` + t.Compile() + `)`, t.Arguments() + } + } else if len(values) == 0 { + return `NULL`, nil + } + + return "", []interface{}{arg} +} + +func Preprocess(in string, args []interface{}) (string, []interface{}) { + return expandQuery(in, args, preprocessFn) +} + +func expandPlaceholders(in string, args []interface{}) (string, []interface{}) { + // TODO: Remove after immutable query builder + return in, args +} + // toWhereWithArguments converts the given parameters into a exql.Where // value. func toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { @@ -161,7 +182,7 @@ func toWhereWithArguments(term interface{}) (where exql.Where, args []interface{ if len(t) > 0 { if s, ok := t[0].(string); ok { if strings.ContainsAny(s, "?") || len(t) == 1 { - s, args = expandPlaceholders(s, t[1:]...) + s, args = expandPlaceholders(s, t[1:]) where.Conditions = []exql.Fragment{exql.RawValue(s)} } else { var val interface{} @@ -190,7 +211,7 @@ func toWhereWithArguments(term interface{}) (where exql.Where, args []interface{ } return case db.RawValue: - r, v := expandPlaceholders(t.Raw(), t.Arguments()...) + r, v := expandPlaceholders(t.Raw(), t.Arguments()) where.Conditions = []exql.Fragment{exql.RawValue(r)} args = append(args, v...) return @@ -308,11 +329,11 @@ func toColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) // A function with one or more arguments. fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) columnValue.Value = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: - expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...) + expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()) columnValue.Value = exql.RawValue(expanded) args = append(args, rawArgs...) default: diff --git a/lib/sqlbuilder/placeholder_test.go b/lib/sqlbuilder/placeholder_test.go index 3f05da3912e09c504e46f3775f576717f170caec..80917b719549512a4d1051c3d55a66da0ca9a03c 100644 --- a/lib/sqlbuilder/placeholder_test.go +++ b/lib/sqlbuilder/placeholder_test.go @@ -9,74 +9,74 @@ import ( func TestPlaceholderSimple(t *testing.T) { { - ret, _ := expandPlaceholders("?", 1) + ret, _ := Preprocess("?", []interface{}{1}) assert.Equal(t, "?", ret) } { - ret, _ := expandPlaceholders("?") + ret, _ := Preprocess("?", nil) assert.Equal(t, "?", ret) } } func TestPlaceholderMany(t *testing.T) { { - ret, _ := expandPlaceholders("?, ?, ?", 1, 2, 3) + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, 3}) assert.Equal(t, "?, ?, ?", ret) } } func TestPlaceholderArray(t *testing.T) { { - ret, _ := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) assert.Equal(t, "?, ?, (?, ?, ?)", ret) } { - ret, _ := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + ret, _ := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) assert.Equal(t, "(?, ?, ?), ?, ?", ret) } { - ret, _ := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + ret, _ := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) assert.Equal(t, "?, (?, ?, ?), ?", ret) } { - ret, _ := expandPlaceholders("???", 1, []interface{}{2, 3, 4}, 5) + ret, _ := Preprocess("???", []interface{}{1, []interface{}{2, 3, 4}, 5}) assert.Equal(t, "?(?, ?, ?)?", ret) } { - ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}) + ret, _ := Preprocess("??", []interface{}{[]interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}}) assert.Equal(t, "(?, ?, ?)(NULL)", ret) } } func TestPlaceholderArguments(t *testing.T) { { - _, args := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + _, args := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } { - _, args := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + _, args := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } { - _, args := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + _, args := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } { - _, args := expandPlaceholders("?, ?", []interface{}{1, 2, 3}, []interface{}{4, 5}) + _, args := Preprocess("?, ?", []interface{}{[]interface{}{1, 2, 3}, []interface{}{4, 5}}) assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) } } func TestPlaceholderReplace(t *testing.T) { { - ret, args := expandPlaceholders("?, ?, ?", 1, db.Raw("foo"), 3) + ret, args := Preprocess("?, ?, ?", []interface{}{1, db.Raw("foo"), 3}) assert.Equal(t, "?, foo, ?", ret) assert.Equal(t, []interface{}{1, 3}, args) } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 72d6f7425bf60dcf55b00d6ae2e84b24eed3fd04..b22d05a4f4ae1f63ba3e4455a05c71033ad2b82a 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -243,7 +243,7 @@ func (sel *selector) OrderBy(columns ...interface{}) Selector { switch value := columns[i].(type) { case db.RawValue: - col, args := expandPlaceholders(value.Raw(), value.Arguments()...) + col, args := expandPlaceholders(value.Raw(), value.Arguments()) sort = &exql.SortColumn{ Column: exql.RawValue(col), } @@ -255,7 +255,7 @@ func (sel *selector) OrderBy(columns ...interface{}) Selector { } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) sort = &exql.SortColumn{ Column: exql.RawValue(expanded), } diff --git a/mongo/database.go b/mongo/database.go index f96472321ed40b4670286b8275db438984bbf50a..1bb48287748f6dd579500b5ab7bfaff06a411d16 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -65,6 +65,21 @@ func (s *Source) ConnectionURL() db.ConnectionURL { return s.connURL } +// SetConnMaxLifetime is not supported. +func (s *Source) SetConnMaxLifetime(time.Duration) { + +} + +// SetMaxIdleConns is not supported. +func (s *Source) SetMaxIdleConns(int) { + +} + +// SetMaxOpenConns is not supported. +func (s *Source) SetMaxOpenConns(int) { + +} + // Name returns the name of the database. func (s *Source) Name() string { return s.name diff --git a/mysql/Makefile b/mysql/Makefile index f6271efc87d8c5ce8adb1068460ec5b2f1d7ef8e..4b14074403bc296d779dee2ade37d054aa084962 100644 --- a/mysql/Makefile +++ b/mysql/Makefile @@ -34,4 +34,5 @@ reset-db: require-client mysql -uroot -h"$(DB_HOST)" -P$(DB_PORT) <<< $$SQL test: reset-db generate - go test -tags generated -v -race + #go test -tags generated -v -race # race: limit on 8192 simultaneously alive goroutines is exceeded, dying + go test -tags generated -v diff --git a/mysql/adapter_test.go b/mysql/adapter_test.go index ca969930905c05d077940ca7b1397c2bac3db2c7..278bc167911830594e38e48d68ff3e23ee292177 100644 --- a/mysql/adapter_test.go +++ b/mysql/adapter_test.go @@ -164,8 +164,8 @@ func cleanUpCheck(sess sqlbuilder.Database) (err error) { return err } - if stats["Prepared_stmt_count"] > 128 { - return fmt.Errorf(`Expecting "Prepared_stmt_count" not to be greater than the prepared statements cache size (128) before cleaning, got %d`, stats["Prepared_stmt_count"]) + if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 { + return fmt.Errorf("Expecting active statements to be at most 128, got %d", activeStatements) } sess.ClearCache() diff --git a/mysql/database.go b/mysql/database.go index f8c2479cd220675f1becf82a4aad71c024e53ad0..2cb66f7345ebb1975124604dafa19df84ffa6a6c 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -133,25 +133,24 @@ func (d *database) clone() (*database, error) { return nil, err } - clone.BaseDatabase = sqladapter.NewBaseDatabase(clone) - - b, err := sqlbuilder.WithSession(clone.BaseDatabase, template) + clone.BaseDatabase, err = d.BindClone(clone) if err != nil { return nil, err } - clone.Builder = b - if err = clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil { + b, err := sqlbuilder.WithSession(clone.BaseDatabase, template) + if err != nil { return nil, err } + clone.Builder = b return clone, nil } // CompileStatement allows sqladapter to compile the given statement into the // format MySQL expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return stmt.Compile(template) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + return sqlbuilder.Preprocess(stmt.Compile(template), args) } // Err allows sqladapter to translate some known errors into generic errors. diff --git a/postgresql/Makefile b/postgresql/Makefile index ab9f9d90438370aea029c2e01e6d2f1a66bce90c..70f59ec6a748f4494e36e7904c6f5dbec4b593a8 100644 --- a/postgresql/Makefile +++ b/postgresql/Makefile @@ -41,4 +41,5 @@ reset-db: require-client fi test: reset-db generate - go test -tags generated -v -race + #go test -tags generated -v -race # race: limit on 8192 simultaneously alive goroutines is exceeded, dying + go test -tags generated -v diff --git a/postgresql/adapter_test.go b/postgresql/adapter_test.go index 4c3b1cbc16ff9280820780c46ffef8bdcf9d6281..72ab2f253f32b11954885d0d0bdf306ae47085dd 100644 --- a/postgresql/adapter_test.go +++ b/postgresql/adapter_test.go @@ -30,6 +30,7 @@ import ( "github.com/stretchr/testify/assert" "upper.io/db.v2" + "upper.io/db.v2/internal/sqladapter" "upper.io/db.v2/lib/sqlbuilder" ) @@ -362,8 +363,8 @@ func cleanUpCheck(sess sqlbuilder.Database) (err error) { return err } - if stats["Prepared_stmt_count"] > 128 { - return fmt.Errorf(`Expecting "Prepared_stmt_count" not to be greater than the prepared statements cache size (128) before cleaning, got %d`, stats["Prepared_stmt_count"]) + if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 { + return fmt.Errorf("Expecting active statements to be at most 128, got %d", activeStatements) } sess.ClearCache() diff --git a/postgresql/database.go b/postgresql/database.go index 7920207e7001ee9e0cf0cd6d725639660b10d3f0..8b92ba2de4431a108de5c934eb25cc038e90f110 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -132,7 +132,10 @@ func (d *database) clone() (*database, error) { return nil, err } - clone.BaseDatabase = sqladapter.NewBaseDatabase(clone) + clone.BaseDatabase, err = d.BindClone(clone) + if err != nil { + return nil, err + } b, err := sqlbuilder.WithSession(clone.BaseDatabase, template) if err != nil { @@ -140,16 +143,14 @@ func (d *database) clone() (*database, error) { } clone.Builder = b - if err = clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil { - return nil, err - } return clone, nil } // CompileStatement allows sqladapter to compile the given statement into the // format PostgreSQL expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return sqladapter.ReplaceWithDollarSign(stmt.Compile(template)) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + query, args := sqlbuilder.Preprocess(stmt.Compile(template), args) + return sqladapter.ReplaceWithDollarSign(query), args } // Err allows sqladapter to translate some known errors into generic errors. diff --git a/postgresql/local_test.go b/postgresql/local_test.go index b09e86ad3e0dbb1c60f3c1401ae2ecb38b561935..82a9069e542b55f50cdaac63ed456be6fc4399d6 100644 --- a/postgresql/local_test.go +++ b/postgresql/local_test.go @@ -117,3 +117,92 @@ func TestIssue210(t *testing.T) { _, err = sess.Collection("hello").Find().Count() assert.NoError(t, err) } + +func TestNonTrivialSubqueries(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + { + q, err := sess.Query(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + ) + + assert.NoError(t, err) + assert.NotNil(t, q) + + assert.True(t, q.Next()) + + var number int + assert.NoError(t, q.Scan(&number)) + + assert.Equal(t, 1, number) + assert.NoError(t, q.Close()) + } + + { + row, err := sess.QueryRow(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + ) + + assert.NoError(t, err) + assert.NotNil(t, row) + + var number int + assert.NoError(t, row.Scan(&number)) + + assert.Equal(t, 1, number) + } + + { + res, err := sess.Exec(`UPDATE artist a1 SET id = ?`, + sess.Select(db.Raw("id + 1")).From("artist a2").Where("a2.id = a1.id"), + ) + + assert.NoError(t, err) + assert.NotNil(t, res) + } + + { + q, err := sess.Query(db.Raw(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + )) + + assert.NoError(t, err) + assert.NotNil(t, q) + + assert.True(t, q.Next()) + + var number int + assert.NoError(t, q.Scan(&number)) + + assert.Equal(t, 2, number) + assert.NoError(t, q.Close()) + } + + { + row, err := sess.QueryRow(db.Raw(`WITH test AS (?) ?`, + sess.Select("id AS foo").From("artist"), + sess.Select("foo").From("test").Where("foo > ?", 0), + )) + + assert.NoError(t, err) + assert.NotNil(t, row) + + var number int + assert.NoError(t, row.Scan(&number)) + + assert.Equal(t, 2, number) + } + + { + res, err := sess.Exec(db.Raw(`UPDATE artist a1 SET id = ?`, + sess.Select(db.Raw("id + 1")).From("artist a2").Where("a2.id = a1.id"), + )) + + assert.NoError(t, err) + assert.NotNil(t, res) + } +} diff --git a/ql/Makefile b/ql/Makefile index 892697d85ccb19afcc8e197d34132ae83432d1b6..25e3c9a862f4f7cc6c8e66fc6b2e3734d04f5236 100644 --- a/ql/Makefile +++ b/ql/Makefile @@ -21,4 +21,5 @@ reset-db: require-client rm -f $(DB_NAME) test: reset-db generate - go test -tags generated -v + #go test -tags generated -v -race # race: limit on 8192 simultaneously alive goroutines is exceeded, dying + go test -tags generated -timeout 30m -v diff --git a/ql/database.go b/ql/database.go index 3a5134cccf59d8352b55f6b466457fb568d1ec5b..59e2d48ff3a6e48e41a6dab53db8c92fa3306f6d 100644 --- a/ql/database.go +++ b/ql/database.go @@ -211,17 +211,25 @@ func (d *database) clone() (*database, error) { return nil, err } - if err := clone.open(); err != nil { + clone.BaseDatabase, err = d.BindClone(clone) + if err != nil { + return nil, err + } + + b, err := sqlbuilder.WithSession(clone.BaseDatabase, template) + if err != nil { return nil, err } + clone.Builder = b return clone, nil } // CompileStatement allows sqladapter to compile the given statement into the // format SQLite expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return sqladapter.ReplaceWithDollarSign(stmt.Compile(template)) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + query, args := sqlbuilder.Preprocess(stmt.Compile(template), args) + return sqladapter.ReplaceWithDollarSign(query), args } // Err allows sqladapter to translate some known errors into generic errors. @@ -235,29 +243,25 @@ func (d *database) Err(err error) error { } // StatementExec wraps the statement to execute around a transaction. -func (d *database) StatementExec(stmt *sql.Stmt, args ...interface{}) (sql.Result, error) { - if d.BaseDatabase.Transaction() == nil { - var tx *sql.Tx - var res sql.Result - var err error - - if tx, err = d.Session().Begin(); err != nil { - return nil, err - } - - s := tx.Stmt(stmt) +func (d *database) StatementExec(query string, args ...interface{}) (res sql.Result, err error) { + if d.Transaction() != nil { + return d.Driver().(*sql.Tx).Exec(query, args...) + } - if res, err = s.Exec(args...); err != nil { - return nil, err - } + sqlTx, err := d.Session().Begin() + if err != nil { + return nil, err + } - if err = tx.Commit(); err != nil { - return nil, err - } + if res, err = sqlTx.Exec(query, args...); err != nil { + return nil, err + } - return res, err + if err = sqlTx.Commit(); err != nil { + return nil, err } - return stmt.Exec(args...) + + return res, err } // NewLocalCollection allows sqladapter create a local db.Collection. diff --git a/sqlite/Makefile b/sqlite/Makefile index 84a5830580d18032f4109a7a786e7291267eb904..44045e0cbc82a4752b3ba31515227d3b44b3a177 100644 --- a/sqlite/Makefile +++ b/sqlite/Makefile @@ -21,4 +21,5 @@ reset-db: require-client rm -f $(DB_NAME) test: reset-db generate - go test -tags generated -v -race + #go test -tags generated -v -race # race: limit on 8192 simultaneously alive goroutines is exceeded, dying + go test -tags generated -v diff --git a/sqlite/connection.go b/sqlite/connection.go index 449d39e078851bb6dfbe339191c1ce35c0ed2822..aa060e4901407fa7373a1e3764564118ae34d2bf 100644 --- a/sqlite/connection.go +++ b/sqlite/connection.go @@ -58,6 +58,10 @@ func (c ConnectionURL) String() (s string) { c.Options = map[string]string{} } + if _, ok := c.Options["_busy_timeout"]; !ok { + c.Options["_busy_timeout"] = "10000" + } + // Converting options into URL values. for k, v := range c.Options { vv.Set(k, v) diff --git a/sqlite/connection_test.go b/sqlite/connection_test.go index 05ac13b53e2197120f25f34f78146536d158a9ed..7baa810d389fc4ade0ccfa0aaa20f16360ad69c6 100644 --- a/sqlite/connection_test.go +++ b/sqlite/connection_test.go @@ -40,7 +40,7 @@ func TestConnectionURL(t *testing.T) { absoluteName, _ := filepath.Abs(c.Database) - if c.String() != "file://"+absoluteName { + if c.String() != "file://"+absoluteName+"?_busy_timeout=10000" { t.Fatal(`Test failed, got:`, c.String()) } @@ -50,14 +50,14 @@ func TestConnectionURL(t *testing.T) { "mode": "ro", } - if c.String() != "file://"+absoluteName+"?cache=foobar&mode=ro" { + if c.String() != "file://"+absoluteName+"?_busy_timeout=10000&cache=foobar&mode=ro" { t.Fatal(`Test failed, got:`, c.String()) } // Setting another database. c.Database = "/another/database" - if c.String() != `file:///another/database?cache=foobar&mode=ro` { + if c.String() != `file:///another/database?_busy_timeout=10000&cache=foobar&mode=ro` { t.Fatal(`Test failed, got:`, c.String()) } @@ -82,7 +82,7 @@ func TestParseConnectionURL(t *testing.T) { t.Fatal("If not defined, cache should be shared by default.") } - s = "file:///path/to/my/database.db?mode=ro&cache=foobar" + s = "file:///path/to/my/database.db?_busy_timeout=10000&mode=ro&cache=foobar" if u, err = ParseURL(s); err != nil { t.Fatal(err) diff --git a/sqlite/database.go b/sqlite/database.go index 32c66e84cbc6026b4c8823bd2f9cfc4f3528f72a..96c85d7076c8b5aad702706b52904764dabbb28d 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -153,17 +153,24 @@ func (d *database) clone() (*database, error) { return nil, err } - if err := clone.open(); err != nil { + clone.BaseDatabase, err = d.BindClone(clone) + if err != nil { + return nil, err + } + + b, err := sqlbuilder.WithSession(clone.BaseDatabase, template) + if err != nil { return nil, err } + clone.Builder = b return clone, nil } // CompileStatement allows sqladapter to compile the given statement into the // format SQLite expects. -func (d *database) CompileStatement(stmt *exql.Statement) string { - return stmt.Compile(template) +func (d *database) CompileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { + return sqlbuilder.Preprocess(stmt.Compile(template), args) } // Err allows sqladapter to translate some known errors into generic errors. @@ -176,6 +183,31 @@ func (d *database) Err(err error) error { return err } +// StatementExec wraps the statement to execute around a transaction. +func (d *database) StatementExec(query string, args ...interface{}) (res sql.Result, err error) { + d.txMu.Lock() + defer d.txMu.Unlock() + + if d.Transaction() != nil { + return d.Driver().(*sql.Tx).Exec(query, args...) + } + + sqlTx, err := d.Session().Begin() + if err != nil { + return nil, err + } + + if res, err = sqlTx.Exec(query, args...); err != nil { + return nil, err + } + + if err = sqlTx.Commit(); err != nil { + return nil, err + } + + return res, err +} + // NewLocalCollection allows sqladapter create a local db.Collection. func (d *database) NewLocalCollection(name string) db.Collection { return newTable(d, name) diff --git a/sqlite/tx.go b/sqlite/tx.go index 19948754393290940751aad8dd6a4c297a20fd83..c39e76e0a80f20f5ce9f3e05fc24cd6e83b0d6e9 100644 --- a/sqlite/tx.go +++ b/sqlite/tx.go @@ -22,8 +22,8 @@ package sqlite import ( - "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter" + "upper.io/db.v2/lib/sqlbuilder" ) type tx struct { @@ -31,19 +31,5 @@ type tx struct { } var ( - _ = db.Tx(&tx{}) + _ = sqlbuilder.Tx(&tx{}) ) - -func (t *tx) Commit() error { - if sess := t.Session(); sess != nil { - defer sess.Close() - } - return t.DatabaseTx.Commit() -} - -func (t *tx) Rollback() error { - if sess := t.Session(); sess != nil { - defer sess.Close() - } - return t.DatabaseTx.Rollback() -}