From 16e17b8a2b559c8692c306ac0b6a901e6f77e1f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Fri, 14 Oct 2016 15:23:22 -0500
Subject: [PATCH] Add NumActiveStatements() and test case to make sure we don't
 leak prepared statements.

---
 internal/sqladapter/statement.go           | 26 +++++++++++-
 internal/sqladapter/testing/adapter.go.tpl | 47 +++++++++++++++++++++-
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go
index be520be8..b8c73938 100644
--- a/internal/sqladapter/statement.go
+++ b/internal/sqladapter/statement.go
@@ -2,9 +2,23 @@ package sqladapter
 
 import (
 	"database/sql"
+	"sync"
 	"sync/atomic"
 )
 
+var (
+	statements   = make(map[*Stmt]bool)
+	statementsMu sync.Mutex
+)
+
+// NumActiveStatements returns the number of prepared statements in use at any
+// point.
+func NumActiveStatements() int {
+	statementsMu.Lock()
+	defer statementsMu.Unlock()
+	return len(statements)
+}
+
 // Stmt represents a *sql.Stmt that is cached and provides the
 // OnPurge method to allow it to clean after itself.
 type Stmt struct {
@@ -18,11 +32,15 @@ type Stmt struct {
 
 // NewStatement creates an returns an opened statement
 func NewStatement(stmt *sql.Stmt, query string) *Stmt {
-	return &Stmt{
+	s := &Stmt{
 		Stmt:  stmt,
 		query: query,
 		count: 1,
 	}
+	statementsMu.Lock()
+	statements[s] = true
+	statementsMu.Unlock()
+	return s
 }
 
 // Open marks the statement as in-use
@@ -41,6 +59,10 @@ func (c *Stmt) Close() {
 	if atomic.LoadInt32(&c.dead) > 0 {
 		// Statement is dead and we can close it for real.
 		c.Stmt.Close()
+
+		statementsMu.Lock()
+		delete(statements, c)
+		statementsMu.Unlock()
 	}
 }
 
@@ -49,4 +71,6 @@ func (c *Stmt) OnPurge() {
 	// Mark as dead, you can continue using it but it will be closed for real
 	// when c.count reaches 0.
 	atomic.StoreInt32(&c.dead, 1)
+	// Call Close again to make sure we're closing the statement.
+	c.Close()
 }
diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index 6cc8dc27..eb539c50 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -6,18 +6,19 @@ import (
 	"database/sql"
 	"flag"
 	"fmt"
-	"log"
 	"math/rand"
 	"os"
 	"strconv"
 	"strings"
 	"sync"
 	"testing"
-	"time"
+  "log"
+  "time"
 
 	"github.com/stretchr/testify/assert"
 	"upper.io/db.v2"
 	"upper.io/db.v2/lib/sqlbuilder"
+  "upper.io/db.v2/internal/sqladapter"
 )
 
 type customLogger struct {
@@ -75,6 +76,48 @@ func TestOpenMustSucceed(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestPreparedStatementsCache(t *testing.T) {
+	sess, err := Open(settings)
+	assert.NoError(t, err)
+	defer sess.Close()
+
+  var tMu sync.Mutex
+	tFatal := func(err error) {
+		tMu.Lock()
+		defer tMu.Unlock()
+		t.Fatal(err)
+	}
+
+  // The max number of elements we can have on our LRU is 128, if an statement
+  // is evicted it will be marked as dead and will be closed only when no other
+  // queries are using it.
+  const maxPreparedStatements = 128*2
+
+  var wg sync.WaitGroup
+  for i := 0; i < 500; i++ {
+    wg.Add(1)
+    go func(i int) {
+      defer wg.Done()
+      // This query is different with each iteration and thus generates a new
+      // prepared statement everytime it's called.
+      res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i)))
+      var count map[string]uint64
+      err := res.One(&count)
+      if err != nil {
+        tFatal(err)
+      }
+      if sqladapter.NumActiveStatements() > maxPreparedStatements {
+        tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements))
+      }
+    }(i)
+    if i%maxPreparedStatements == 0 {
+      wg.Wait()
+    }
+  }
+
+  wg.Wait()
+}
+
 func TestTruncateAllCollections(t *testing.T) {
 	sess, err := Open(settings)
 	assert.NoError(t, err)
-- 
GitLab