From 54b5584799a8206c53a4447449bc064eda3eeeae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Sun, 16 Oct 2016 08:48:48 -0500
Subject: [PATCH] Use a counter to get the number of prepared statements.

---
 internal/sqladapter/statement.go           | 21 +++----
 internal/sqladapter/testing/adapter.go.tpl | 64 +++++++++++-----------
 2 files changed, 39 insertions(+), 46 deletions(-)

diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go
index b8c73938..0a8a4482 100644
--- a/internal/sqladapter/statement.go
+++ b/internal/sqladapter/statement.go
@@ -2,21 +2,17 @@ package sqladapter
 
 import (
 	"database/sql"
-	"sync"
 	"sync/atomic"
 )
 
 var (
-	statements   = make(map[*Stmt]bool)
-	statementsMu sync.Mutex
+	activeStatements int64
 )
 
 // NumActiveStatements returns the number of prepared statements in use at any
 // point.
-func NumActiveStatements() int {
-	statementsMu.Lock()
-	defer statementsMu.Unlock()
-	return len(statements)
+func NumActiveStatements() int64 {
+	return atomic.LoadInt64(&activeStatements)
 }
 
 // Stmt represents a *sql.Stmt that is cached and provides the
@@ -37,9 +33,8 @@ func NewStatement(stmt *sql.Stmt, query string) *Stmt {
 		query: query,
 		count: 1,
 	}
-	statementsMu.Lock()
-	statements[s] = true
-	statementsMu.Unlock()
+	// Increment active statements counter.
+	atomic.AddInt64(&activeStatements, 1)
 	return s
 }
 
@@ -59,10 +54,8 @@ func (c *Stmt) Close() {
 	if atomic.LoadInt32(&c.dead) > 0 {
 		// Statement is dead and we can close it for real.
 		c.Stmt.Close()
-
-		statementsMu.Lock()
-		delete(statements, c)
-		statementsMu.Unlock()
+		// Reduce active statements counter.
+		atomic.AddInt64(&activeStatements, -1)
 	}
 }
 
diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index eb539c50..7367a2d4 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -6,19 +6,19 @@ import (
 	"database/sql"
 	"flag"
 	"fmt"
+	"log"
 	"math/rand"
 	"os"
 	"strconv"
 	"strings"
 	"sync"
 	"testing"
-  "log"
-  "time"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"upper.io/db.v2"
+	"upper.io/db.v2/internal/sqladapter"
 	"upper.io/db.v2/lib/sqlbuilder"
-  "upper.io/db.v2/internal/sqladapter"
 )
 
 type customLogger struct {
@@ -81,41 +81,41 @@ func TestPreparedStatementsCache(t *testing.T) {
 	assert.NoError(t, err)
 	defer sess.Close()
 
-  var tMu sync.Mutex
+	var tMu sync.Mutex
 	tFatal := func(err error) {
 		tMu.Lock()
 		defer tMu.Unlock()
 		t.Fatal(err)
 	}
 
-  // The max number of elements we can have on our LRU is 128, if an statement
-  // is evicted it will be marked as dead and will be closed only when no other
-  // queries are using it.
-  const maxPreparedStatements = 128*2
-
-  var wg sync.WaitGroup
-  for i := 0; i < 500; i++ {
-    wg.Add(1)
-    go func(i int) {
-      defer wg.Done()
-      // This query is different with each iteration and thus generates a new
-      // prepared statement everytime it's called.
-      res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i)))
-      var count map[string]uint64
-      err := res.One(&count)
-      if err != nil {
-        tFatal(err)
-      }
-      if sqladapter.NumActiveStatements() > maxPreparedStatements {
-        tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements))
-      }
-    }(i)
-    if i%maxPreparedStatements == 0 {
-      wg.Wait()
-    }
-  }
-
-  wg.Wait()
+	// The max number of elements we can have on our LRU is 128, if an statement
+	// is evicted it will be marked as dead and will be closed only when no other
+	// queries are using it.
+	const maxPreparedStatements = 128 * 2
+
+	var wg sync.WaitGroup
+	for i := 0; i < 500; i++ {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+			// This query is different with each iteration and thus generates a new
+			// prepared statement everytime it's called.
+			res := sess.Collection("artist").Find().Select(db.Raw(fmt.Sprintf("COUNT(%d)", i)))
+			var count map[string]uint64
+			err := res.One(&count)
+			if err != nil {
+				tFatal(err)
+			}
+			if sqladapter.NumActiveStatements() > maxPreparedStatements {
+				tFatal(fmt.Errorf("The number of active statements cannot exceed %d.", maxPreparedStatements))
+			}
+		}(i)
+		if i%maxPreparedStatements == 0 {
+			wg.Wait()
+		}
+	}
+
+	wg.Wait()
 }
 
 func TestTruncateAllCollections(t *testing.T) {
-- 
GitLab