From 5de5dfff3a74fe11d45b4721b51581e62051d040 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 22 Nov 2016 10:49:46 -0600
Subject: [PATCH] Add cleanUpCheck to count number of prepared statements.

---
 mysql/adapter_test.go      | 59 ++++++++++++++++++++++++++++++++++++++
 mysql/database.go          |  8 +++++-
 postgresql/adapter_test.go | 42 +++++++++++++++++++++++++++
 postgresql/database.go     |  4 ++-
 ql/adapter_test.go         |  7 +++++
 sqlite/adapter_test.go     |  7 +++++
 sqlite/database.go         |  3 ++
 7 files changed, 128 insertions(+), 2 deletions(-)

diff --git a/mysql/adapter_test.go b/mysql/adapter_test.go
index 0327e6a5..ca969930 100644
--- a/mysql/adapter_test.go
+++ b/mysql/adapter_test.go
@@ -26,6 +26,10 @@ import (
 	"database/sql"
 	"fmt"
 	"os"
+	"time"
+
+	"upper.io/db.v2/internal/sqladapter"
+	"upper.io/db.v2/lib/sqlbuilder"
 )
 
 const (
@@ -131,3 +135,58 @@ func tearUp() error {
 
 	return nil
 }
+
+func getStats(sess sqlbuilder.Database) (map[string]int, error) {
+	stats := make(map[string]int)
+
+	res, err := sess.Driver().(*sql.DB).Query(`SHOW GLOBAL STATUS LIKE '%stmt%'`)
+	if err != nil {
+		return nil, err
+	}
+	var result struct {
+		VariableName string `db:"Variable_name"`
+		Value        int    `db:"Value"`
+	}
+
+	iter := sqlbuilder.NewIterator(res)
+	for iter.Next(&result) {
+		stats[result.VariableName] = result.Value
+	}
+
+	return stats, nil
+}
+
+func cleanUpCheck(sess sqlbuilder.Database) (err error) {
+	var stats map[string]int
+
+	stats, err = getStats(sess)
+	if err != nil {
+		return err
+	}
+
+	if stats["Prepared_stmt_count"] > 128 {
+		return fmt.Errorf(`Expecting "Prepared_stmt_count" not to be greater than the prepared statements cache size (128) before cleaning, got %d`, stats["Prepared_stmt_count"])
+	}
+
+	sess.ClearCache()
+
+	if activeStatements := sqladapter.NumActiveStatements(); activeStatements != 0 {
+		return fmt.Errorf("Expecting active statements to be 0, got %d", activeStatements)
+	}
+
+	for i := 0; i < 10; i++ {
+		stats, err = getStats(sess)
+		if err != nil {
+			return err
+		}
+
+		if stats["Prepared_stmt_count"] != 0 {
+			time.Sleep(time.Millisecond * 200) // Sometimes it takes a bit to clean prepared statements
+			err = fmt.Errorf(`Expecting "Prepared_stmt_count" to be 0, got %d`, stats["Prepared_stmt_count"])
+			continue
+		}
+		break
+	}
+
+	return err
+}
diff --git a/mysql/database.go b/mysql/database.go
index 8263e40c..f8c2479c 100644
--- a/mysql/database.go
+++ b/mysql/database.go
@@ -141,7 +141,10 @@ func (d *database) clone() (*database, error) {
 	}
 	clone.Builder = b
 
-	clone.BaseDatabase.BindSession(d.BaseDatabase.Session())
+	if err = clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil {
+		return nil, err
+	}
+
 	return clone, nil
 }
 
@@ -181,6 +184,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
 		return nil, err
 	}
 
+	clone.txMu.Lock()
+	defer clone.txMu.Unlock()
+
 	connFn := func() error {
 		sqlTx, err := clone.BaseDatabase.Session().Begin()
 		if err == nil {
diff --git a/postgresql/adapter_test.go b/postgresql/adapter_test.go
index 639e3ccf..4c3b1cbc 100644
--- a/postgresql/adapter_test.go
+++ b/postgresql/adapter_test.go
@@ -24,11 +24,13 @@ package postgresql
 
 import (
 	"database/sql"
+	"fmt"
 	"os"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"upper.io/db.v2"
+	"upper.io/db.v2/lib/sqlbuilder"
 )
 
 const (
@@ -336,3 +338,43 @@ func TestOptionTypeJsonbStruct(t *testing.T) {
 	assert.Equal(t, "a", item1Chk.Settings.Name)
 	assert.Equal(t, int64(123), item1Chk.Settings.Num)
 }
+
+func getStats(sess sqlbuilder.Database) (map[string]int, error) {
+	stats := make(map[string]int)
+
+	row := sess.Driver().(*sql.DB).QueryRow(`SELECT count(1) AS value FROM pg_prepared_statements`)
+
+	var value int
+	err := row.Scan(&value)
+	if err != nil {
+		return nil, err
+	}
+
+	stats["pg_prepared_statements_count"] = value
+
+	return stats, nil
+}
+
+func cleanUpCheck(sess sqlbuilder.Database) (err error) {
+	var stats map[string]int
+	stats, err = getStats(sess)
+	if err != nil {
+		return err
+	}
+
+	if stats["Prepared_stmt_count"] > 128 {
+		return fmt.Errorf(`Expecting "Prepared_stmt_count" not to be greater than the prepared statements cache size (128) before cleaning, got %d`, stats["Prepared_stmt_count"])
+	}
+
+	sess.ClearCache()
+
+	stats, err = getStats(sess)
+	if err != nil {
+		return err
+	}
+
+	if stats["pg_prepared_statements_count"] != 0 {
+		return fmt.Errorf(`Expecting "Prepared_stmt_count" to be 0, got %d`, stats["Prepared_stmt_count"])
+	}
+	return nil
+}
diff --git a/postgresql/database.go b/postgresql/database.go
index d7807d99..7920207e 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -140,7 +140,9 @@ func (d *database) clone() (*database, error) {
 	}
 	clone.Builder = b
 
-	clone.BaseDatabase.BindSession(d.BaseDatabase.Session())
+	if err = clone.BaseDatabase.BindSession(d.BaseDatabase.Session()); err != nil {
+		return nil, err
+	}
 	return clone, nil
 }
 
diff --git a/ql/adapter_test.go b/ql/adapter_test.go
index a7a1998c..2715c9b0 100644
--- a/ql/adapter_test.go
+++ b/ql/adapter_test.go
@@ -25,6 +25,8 @@ package ql
 import (
 	"database/sql"
 	"os"
+
+	"upper.io/db.v2/lib/sqlbuilder"
 )
 
 const (
@@ -123,3 +125,8 @@ func tearUp() error {
 
 	return nil
 }
+
+func cleanUpCheck(sess sqlbuilder.Database) (err error) {
+	// TODO: Check the number of prepared statements.
+	return nil
+}
diff --git a/sqlite/adapter_test.go b/sqlite/adapter_test.go
index 10d7a552..28eced22 100644
--- a/sqlite/adapter_test.go
+++ b/sqlite/adapter_test.go
@@ -25,6 +25,8 @@ package sqlite
 import (
 	"database/sql"
 	"os"
+
+	"upper.io/db.v2/lib/sqlbuilder"
 )
 
 const (
@@ -126,3 +128,8 @@ func tearUp() error {
 
 	return nil
 }
+
+func cleanUpCheck(sess sqlbuilder.Database) (err error) {
+	// TODO: Check the number of prepared statements.
+	return nil
+}
diff --git a/sqlite/database.go b/sqlite/database.go
index db824671..32c66e84 100644
--- a/sqlite/database.go
+++ b/sqlite/database.go
@@ -194,6 +194,9 @@ func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) {
 		return nil, err
 	}
 
+	clone.txMu.Lock()
+	defer clone.txMu.Unlock()
+
 	openFn := func() error {
 		sqlTx, err := clone.BaseDatabase.Session().Begin()
 		if err == nil {
-- 
GitLab