diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go index b8c73938dfdd57874252797d6f50c1f752ee2732..0a8a4482e894fb1d4d86c00b2df387a98f9aaf7e 100644 --- a/internal/sqladapter/statement.go +++ b/internal/sqladapter/statement.go @@ -2,21 +2,17 @@ package sqladapter import ( "database/sql" - "sync" "sync/atomic" ) var ( - statements = make(map[*Stmt]bool) - statementsMu sync.Mutex + activeStatements int64 ) // NumActiveStatements returns the number of prepared statements in use at any // point. -func NumActiveStatements() int { - statementsMu.Lock() - defer statementsMu.Unlock() - return len(statements) +func NumActiveStatements() int64 { + return atomic.LoadInt64(&activeStatements) } // Stmt represents a *sql.Stmt that is cached and provides the @@ -37,9 +33,8 @@ func NewStatement(stmt *sql.Stmt, query string) *Stmt { query: query, count: 1, } - statementsMu.Lock() - statements[s] = true - statementsMu.Unlock() + // Increment active statements counter. + atomic.AddInt64(&activeStatements, 1) return s } @@ -59,10 +54,8 @@ func (c *Stmt) Close() { if atomic.LoadInt32(&c.dead) > 0 { // Statement is dead and we can close it for real. c.Stmt.Close() - - statementsMu.Lock() - delete(statements, c) - statementsMu.Unlock() + // Reduce active statements counter. + atomic.AddInt64(&activeStatements, -1) } } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index eb539c50819cf3724953a5cbe9c1490a88ac2c23..7367a2d45fbf5c456f76a5456d023e0e56b1421e 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -6,19 +6,19 @@ import ( "database/sql" "flag" "fmt" + "log" "math/rand" "os" "strconv" "strings" "sync" "testing" - "log" - "time" + "time" "github.com/stretchr/testify/assert" "upper.io/db.v2" + "upper.io/db.v2/internal/sqladapter" "upper.io/db.v2/lib/sqlbuilder" - "upper.io/db.v2/internal/sqladapter" ) type customLogger struct { @@ -81,41 +81,41 @@ func TestPreparedStatementsCache(t *testing.T) { assert.NoError(t, err) defer sess.Close() - var tMu sync.Mutex + var tMu sync.Mutex tFatal := func(err error) { tMu.Lock() defer tMu.Unlock() t.Fatal(err) } - // The max number of elements we can have on our LRU is 128, if an statement - // is evicted it will be marked as dead and will be closed only when no other - // queries are using it. - const maxPreparedStatements = 128*2 - - var wg sync.WaitGroup - for i := 0; i < 500; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - // This query is different with each iteration and thus generates a new - // prepared statement everytime it's called. - res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i))) - var count map[string]uint64 - err := res.One(&count) - if err != nil { - tFatal(err) - } - if sqladapter.NumActiveStatements() > maxPreparedStatements { - tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements)) - } - }(i) - if i%maxPreparedStatements == 0 { - wg.Wait() - } - } - - wg.Wait() + // The max number of elements we can have on our LRU is 128, if an statement + // is evicted it will be marked as dead and will be closed only when no other + // queries are using it. + const maxPreparedStatements = 128 * 2 + + var wg sync.WaitGroup + for i := 0; i < 500; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // This query is different with each iteration and thus generates a new + // prepared statement everytime it's called. + res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i))) + var count map[string]uint64 + err := res.One(&count) + if err != nil { + tFatal(err) + } + if sqladapter.NumActiveStatements() > maxPreparedStatements { + tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements)) + } + }(i) + if i%maxPreparedStatements == 0 { + wg.Wait() + } + } + + wg.Wait() } func TestTruncateAllCollections(t *testing.T) {