diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 826a5902434244f2dffc5d7a04707d062493f1ab..d81cebe246f6ff0f768cee0d986dbd4d66afbf71 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -83,8 +83,8 @@ func (c *Cache) Read(h Hashable) (string, bool) { // does not exists returns nil and false. func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) { c.mu.RLock() + defer c.mu.RUnlock() data, ok := c.cache[h.Hash()] - c.mu.RUnlock() if ok { return data.Value.(*item).value, true } @@ -106,7 +106,7 @@ func (c *Cache) Write(h Hashable, value interface{}) { c.cache[key] = c.li.PushFront(&item{key, value}) - if c.li.Len() > c.capacity { + for c.li.Len() > c.capacity { el := c.li.Remove(c.li.Back()) delete(c.cache, el.(*item).key) if p, ok := el.(*item).value.(HasOnPurge); ok { diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index ca8477c1e527314181973090210307edda3bce50..fd9a3577efa84637fbc25f7fe47b503ab1934460 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -267,13 +267,14 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res }(time.Now()) } - var p *sql.Stmt + var p *Stmt if p, query, err = d.prepareStatement(stmt); err != nil { return nil, err } + defer p.Close() if execer, ok := d.PartialDatabase.(HasStatementExec); ok { - res, err = execer.StatementExec(p, args...) + res, err = execer.StatementExec(p.Stmt, args...) return } @@ -299,10 +300,11 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro }(time.Now()) } - var p *sql.Stmt + var p *Stmt if p, query, err = d.prepareStatement(stmt); err != nil { return nil, err } + defer p.Close() rows, err = p.Query(args...) return @@ -327,10 +329,11 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) }(time.Now()) } - var p *sql.Stmt + var p *Stmt if p, query, err = d.prepareStatement(stmt); err != nil { return nil, err } + defer p.Close() row, err = p.QueryRow(args...), nil return @@ -348,37 +351,33 @@ func (d *database) Driver() interface{} { // prepareStatement converts a *exql.Statement representation into an actual // *sql.Stmt. This method will attempt to used a cached prepared statement, if // available. -func (d *database) prepareStatement(stmt *exql.Statement) (*sql.Stmt, string, error) { +func (d *database) prepareStatement(stmt *exql.Statement) (*Stmt, string, error) { if d.sess == nil && d.Transaction() == nil { return nil, "", db.ErrNotConnected } pc, ok := d.cachedStatements.ReadRaw(stmt) - if ok { // The statement was cached. - ps := pc.(*cachedStatement) - return ps.Stmt, ps.query, nil + ps := pc.(*Stmt).open() + return ps, ps.query, nil } // Plain SQL query. query := d.PartialDatabase.CompileStatement(stmt) - var p *sql.Stmt - var err error - - if d.Transaction() != nil { - p, err = d.Transaction().(*sqlTx).Prepare(query) - } else { - p, err = d.sess.Prepare(query) - } - + sqlStmt, err := func() (*sql.Stmt, error) { + if d.Transaction() != nil { + return d.Transaction().(*sqlTx).Prepare(query) + } + return d.sess.Prepare(query) + }() if err != nil { return nil, query, err } - d.cachedStatements.Write(stmt, &cachedStatement{p, query}) - + p := newCachedStatement(sqlStmt, query) + d.cachedStatements.Write(stmt, p) return p, query, nil } diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go index 51e4f1d7f0fbecf3996a29483e09b2c504750d83..b9907d60c8db3109380562cf1afe07fe5158b344 100644 --- a/internal/sqladapter/statement.go +++ b/internal/sqladapter/statement.go @@ -2,15 +2,47 @@ package sqladapter import ( "database/sql" + "sync/atomic" ) -// cachedStatement represents a *sql.Stmt that is cached and provides the +// Stmt represents a *sql.Stmt that is cached and provides the // OnPurge method to allow it to clean after itself. -type cachedStatement struct { +type Stmt struct { *sql.Stmt + query string + + count int64 + dead int32 +} + +func newCachedStatement(stmt *sql.Stmt, query string) *Stmt { + return &Stmt{ + Stmt: stmt, + query: query, + count: 1, + } +} + +func (c *Stmt) open() *Stmt { + atomic.AddInt64(&c.count, 1) + return c +} + +func (c *Stmt) Close() { + if atomic.AddInt64(&c.count, -1) > 0 { + // There are another goroutines using this statement so we don't want to + // close it for real. + return + } + if atomic.LoadInt32(&c.dead) > 0 { + // Statement is dead and we can close it for real. + c.Stmt.Close() + } } -func (c *cachedStatement) OnPurge() { - c.Stmt.Close() +func (c *Stmt) OnPurge() { + // Mark as dead, you can continue using it but it will be closed for real + // when c.count reaches 0. + atomic.StoreInt32(&c.dead, 1) } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 92b2e89974a7cf18ef88c23e1444cbc9f8467527..9f185d9016f7a8b5ef992236d092659c4d32fe5d 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -1370,6 +1370,34 @@ func TestBuilder(t *testing.T) { assert.NotZero(t, all) } +func TestStressPreparedStatementCache(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + var tMu sync.Mutex + tFatal := func(err error) { + tMu.Lock() + defer tMu.Unlock() + t.Fatal(err) + } + + var wg sync.WaitGroup + + for i := 1; i < 1000; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i%5))) + var data map[string]interface{} + if err := res.One(&data); err != nil { + tFatal(err) + } + }(i) + } + + wg.Wait() +} + func TestExhaustConnectionPool(t *testing.T) { if Adapter == "ql" { t.Skip("Currently not supported.")