diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go index be520be8fbf69f4f0b7bda247d944a3784f355a3..0a8a4482e894fb1d4d86c00b2df387a98f9aaf7e 100644 --- a/internal/sqladapter/statement.go +++ b/internal/sqladapter/statement.go @@ -5,6 +5,16 @@ import ( "sync/atomic" ) +var ( + activeStatements int64 +) + +// NumActiveStatements returns the number of prepared statements in use at any +// point. +func NumActiveStatements() int64 { + return atomic.LoadInt64(&activeStatements) +} + // Stmt represents a *sql.Stmt that is cached and provides the // OnPurge method to allow it to clean after itself. type Stmt struct { @@ -18,11 +28,14 @@ type Stmt struct { // NewStatement creates an returns an opened statement func NewStatement(stmt *sql.Stmt, query string) *Stmt { - return &Stmt{ + s := &Stmt{ Stmt: stmt, query: query, count: 1, } + // Increment active statements counter. + atomic.AddInt64(&activeStatements, 1) + return s } // Open marks the statement as in-use @@ -41,6 +54,8 @@ func (c *Stmt) Close() { if atomic.LoadInt32(&c.dead) > 0 { // Statement is dead and we can close it for real. c.Stmt.Close() + // Reduce active statements counter. + atomic.AddInt64(&activeStatements, -1) } } @@ -49,4 +64,6 @@ func (c *Stmt) OnPurge() { // Mark as dead, you can continue using it but it will be closed for real // when c.count reaches 0. atomic.StoreInt32(&c.dead, 1) + // Call Close again to make sure we're closing the statement. + c.Close() } diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 6cc8dc27c03e8259bb792de77705bc029077d353..7367a2d45fbf5c456f76a5456d023e0e56b1421e 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/assert" "upper.io/db.v2" + "upper.io/db.v2/internal/sqladapter" "upper.io/db.v2/lib/sqlbuilder" ) @@ -75,6 +76,48 @@ func TestOpenMustSucceed(t *testing.T) { assert.NoError(t, err) } +func TestPreparedStatementsCache(t *testing.T) { + sess, err := Open(settings) + assert.NoError(t, err) + defer sess.Close() + + var tMu sync.Mutex + tFatal := func(err error) { + tMu.Lock() + defer tMu.Unlock() + t.Fatal(err) + } + + // The max number of elements we can have on our LRU is 128, if an statement + // is evicted it will be marked as dead and will be closed only when no other + // queries are using it. + const maxPreparedStatements = 128 * 2 + + var wg sync.WaitGroup + for i := 0; i < 500; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // This query is different with each iteration and thus generates a new + // prepared statement everytime it's called. + res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i))) + var count map[string]uint64 + err := res.One(&count) + if err != nil { + tFatal(err) + } + if sqladapter.NumActiveStatements() > maxPreparedStatements { + tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements)) + } + }(i) + if i%maxPreparedStatements == 0 { + wg.Wait() + } + } + + wg.Wait() +} + func TestTruncateAllCollections(t *testing.T) { sess, err := Open(settings) assert.NoError(t, err)