diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go index be520be8fbf69f4f0b7bda247d944a3784f355a3..b8c73938dfdd57874252797d6f50c1f752ee2732 100644 --- a/internal/sqladapter/statement.go +++ b/internal/sqladapter/statement.go @@ -2,9 +2,23 @@ package sqladapter import ( "database/sql" + "sync" "sync/atomic" ) +var ( + statements = make(map[*Stmt]bool) + statementsMu sync.Mutex +) + +// NumActiveStatements returns the number of prepared statements in use at any +// point. +func NumActiveStatements() int { + statementsMu.Lock() + defer statementsMu.Unlock() + return len(statements) +} + // Stmt represents a *sql.Stmt that is cached and provides the // OnPurge method to allow it to clean after itself. type Stmt struct { @@ -18,11 +32,15 @@ type Stmt struct { // NewStatement creates an returns an opened statement func NewStatement(stmt *sql.Stmt, query string) *Stmt { - return &Stmt{ + s := &Stmt{ Stmt: stmt, query: query, count: 1, } + statementsMu.Lock() + statements[s] = true + statementsMu.Unlock() + return s } // Open marks the statement as in-use @@ -41,6 +59,10 @@ func (c *Stmt) Close() { if atomic.LoadInt32(&c.dead) > 0 { // Statement is dead and we can close it for real. c.Stmt.Close() + + statementsMu.Lock() + delete(statements, c) + statementsMu.Unlock() } } @@ -49,4 +71,6 @@ func (c *Stmt) OnPurge() { // Mark as dead, you can continue using it but it will be closed for real // when c.count reaches 0. atomic.StoreInt32(&c.dead, 1) + // Call Close again to make sure we're closing the statement. + c.Close() } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 6cc8dc27c03e8259bb792de77705bc029077d353..eb539c50819cf3724953a5cbe9c1490a88ac2c23 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -6,18 +6,19 @@ import ( "database/sql" "flag" "fmt" - "log" "math/rand" "os" "strconv" "strings" "sync" "testing" - "time" + "log" + "time" "github.com/stretchr/testify/assert" "upper.io/db.v2" "upper.io/db.v2/lib/sqlbuilder" + "upper.io/db.v2/internal/sqladapter" ) type customLogger struct { @@ -75,6 +76,48 @@ func TestOpenMustSucceed(t *testing.T) { assert.NoError(t, err) } +func TestPreparedStatementsCache(t *testing.T) { + sess, err := Open(settings) + assert.NoError(t, err) + defer sess.Close() + + var tMu sync.Mutex + tFatal := func(err error) { + tMu.Lock() + defer tMu.Unlock() + t.Fatal(err) + } + + // The max number of elements we can have on our LRU is 128, if an statement + // is evicted it will be marked as dead and will be closed only when no other + // queries are using it. + const maxPreparedStatements = 128*2 + + var wg sync.WaitGroup + for i := 0; i < 500; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // This query is different with each iteration and thus generates a new + // prepared statement everytime it's called. + res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i))) + var count map[string]uint64 + err := res.One(&count) + if err != nil { + tFatal(err) + } + if sqladapter.NumActiveStatements() > maxPreparedStatements { + tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements)) + } + }(i) + if i%maxPreparedStatements == 0 { + wg.Wait() + } + } + + wg.Wait() +} + func TestTruncateAllCollections(t *testing.T) { sess, err := Open(settings) assert.NoError(t, err)