From f51f7d017d966df76aae45b6f311f2810c065d0c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 9 Aug 2016 21:25:25 -0500
Subject: [PATCH] Add support for connection limits.

---
 db.go                                      |  7 +++++
 internal/sqladapter/collection.go          |  2 +-
 internal/sqladapter/testing/adapter.go.tpl | 14 ++++-----
 mysql/database.go                          |  3 ++
 mysql/mysql.go                             | 34 ++++++++++++++++++++++
 postgresql/database.go                     |  3 ++
 postgresql/postgresql.go                   | 34 ++++++++++++++++++++++
 7 files changed, 88 insertions(+), 9 deletions(-)

diff --git a/db.go b/db.go
index 918050cb..a96a3716 100644
--- a/db.go
+++ b/db.go
@@ -79,6 +79,7 @@ package db // import "upper.io/db.v2"
 import (
 	"fmt"
 	"reflect"
+	"time"
 )
 
 // Constraint interface represents a condition.
@@ -573,3 +574,9 @@ var (
 	_ Constraint  = &constraint{}
 	_ RawValue    = &rawValue{}
 )
+
+var (
+	DefaultConnMaxLifetime = time.Duration(0)
+	DefaultMaxIdleConns    = 0
+	DefaultMaxOpenConns    = 0
+)
diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go
index c616cf4c..14d4ac8a 100644
--- a/internal/sqladapter/collection.go
+++ b/internal/sqladapter/collection.go
@@ -82,10 +82,10 @@ func (c *collection) InsertReturning(item interface{}) error {
 		// Not within a transaction, let's create one.
 		var err error
 		tx, err = c.p.Database().NewLocalTransaction()
-		defer tx.Close()
 		if err != nil {
 			return err
 		}
+		defer tx.Close()
 	}
 
 	var res db.Result
diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index 6c68a922..f9d1af5a 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -830,8 +830,7 @@ func TestDelete(t *testing.T) {
 
 func TestCompositeKeys(t *testing.T) {
 	if Adapter == "ql" {
-		t.Logf("Unsupported, skipped")
-		return
+		t.Skip("Currently not supported.")
 	}
 
 	sess := mustOpen()
@@ -864,8 +863,7 @@ func TestCompositeKeys(t *testing.T) {
 // Attempts to test database transactions.
 func TestTransactionsAndRollback(t *testing.T) {
 	if Adapter == "ql" {
-		t.Logf("Skipped.")
-		return
+		t.Skip("Currently not supported.")
 	}
 
 	sess := mustOpen()
@@ -986,8 +984,7 @@ func TestTransactionsAndRollback(t *testing.T) {
 
 func TestDataTypes(t *testing.T) {
 	if Adapter == "ql" {
-		t.Logf("Skipped.")
-		return
+		t.Skip("Currently not supported.")
 	}
 
 	type testValuesStruct struct {
@@ -1124,11 +1121,11 @@ func TestBuilder(t *testing.T) {
 
 func TestExhaustConnectionPool(t *testing.T) {
 	if Adapter == "ql" {
-		t.Logf("Skipped.")
-		return
+		t.Skip("Currently not supported.")
 	}
 
 	var tMu sync.Mutex
+
 	tFatal := func(err error) {
 		tMu.Lock()
 		defer tMu.Unlock()
@@ -1154,6 +1151,7 @@ func TestExhaustConnectionPool(t *testing.T) {
 
 			// Requesting a new transaction session.
 			start := time.Now()
+			tLogf("Tx: %d: NewTx")
 			tx, err := sess.NewTx()
 			if err != nil {
 				tFatal(err)
diff --git a/mysql/database.go b/mysql/database.go
index 798aa4a9..c660f14d 100644
--- a/mysql/database.go
+++ b/mysql/database.go
@@ -112,6 +112,9 @@ func (d *database) open() error {
 	connFn := func() error {
 		sess, err := sql.Open("mysql", d.ConnectionURL().String())
 		if err == nil {
+			sess.SetConnMaxLifetime(connMaxLifetime)
+			sess.SetMaxIdleConns(maxIdleConns)
+			sess.SetMaxOpenConns(maxOpenConns)
 			return d.BaseDatabase.BindSession(sess)
 		}
 		return err
diff --git a/mysql/mysql.go b/mysql/mysql.go
index 56d42cf4..470db374 100644
--- a/mysql/mysql.go
+++ b/mysql/mysql.go
@@ -23,6 +23,7 @@ package mysql // import "upper.io/db.v2/mysql"
 
 import (
 	"database/sql"
+	"time"
 
 	"upper.io/db.v2"
 
@@ -30,6 +31,12 @@ import (
 	"upper.io/db.v2/lib/sqlbuilder"
 )
 
+var (
+	connMaxLifetime time.Duration = db.DefaultConnMaxLifetime
+	maxIdleConns    int           = db.DefaultMaxIdleConns
+	maxOpenConns    int           = db.DefaultMaxOpenConns
+)
+
 const sqlDriver = `mysql`
 
 // Adapter is the public name of the adapter.
@@ -102,3 +109,30 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) {
 	}
 	return d, nil
 }
+
+// SetConnMaxLifetime sets the default value to be passed to
+// db.SetConnMaxLifetime.
+func SetConnMaxLifetime(d time.Duration) {
+	connMaxLifetime = d
+}
+
+// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns.
+func SetMaxIdleConns(n int) {
+	if n < 0 {
+		n = 0
+	}
+	maxIdleConns = n
+}
+
+// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns.
+// If the value of maxIdleConns is >= 0 and maxOpenConns is less than
+// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns.
+func SetMaxOpenConns(n int) {
+	if n < 0 {
+		n = 0
+	}
+	if n > maxIdleConns {
+		maxIdleConns = n
+	}
+	maxOpenConns = n
+}
diff --git a/postgresql/database.go b/postgresql/database.go
index 28a61791..0933942e 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -111,6 +111,9 @@ func (d *database) open() error {
 	connFn := func() error {
 		sess, err := sql.Open("postgres", d.ConnectionURL().String())
 		if err == nil {
+			sess.SetConnMaxLifetime(connMaxLifetime)
+			sess.SetMaxIdleConns(maxIdleConns)
+			sess.SetMaxOpenConns(maxOpenConns)
 			return d.BaseDatabase.BindSession(sess)
 		}
 		return err
diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go
index 9612e6aa..0572350a 100644
--- a/postgresql/postgresql.go
+++ b/postgresql/postgresql.go
@@ -23,6 +23,7 @@ package postgresql // import "upper.io/db.v2/postgresql"
 
 import (
 	"database/sql"
+	"time"
 
 	"upper.io/db.v2"
 
@@ -30,6 +31,12 @@ import (
 	"upper.io/db.v2/lib/sqlbuilder"
 )
 
+var (
+	connMaxLifetime time.Duration = db.DefaultConnMaxLifetime
+	maxIdleConns    int           = db.DefaultMaxIdleConns
+	maxOpenConns    int           = db.DefaultMaxOpenConns
+)
+
 const sqlDriver = `postgres`
 
 // Adapter is the public name of the adapter.
@@ -102,3 +109,30 @@ func New(sess *sql.DB) (sqlbuilder.Database, error) {
 	}
 	return d, nil
 }
+
+// SetConnMaxLifetime sets the default value to be passed to
+// db.SetConnMaxLifetime.
+func SetConnMaxLifetime(d time.Duration) {
+	connMaxLifetime = d
+}
+
+// SetMaxIdleConns sets the default value to be passed to db.SetMaxOpenConns.
+func SetMaxIdleConns(n int) {
+	if n < 0 {
+		n = 0
+	}
+	maxIdleConns = n
+}
+
+// SetMaxOpenConns sets the default value to be passed to db.SetMaxOpenConns.
+// If the value of maxIdleConns is >= 0 and maxOpenConns is less than
+// maxIdleConns, then maxIdleConns will be reduced to match maxOpenConns.
+func SetMaxOpenConns(n int) {
+	if n < 0 {
+		n = 0
+	}
+	if n > maxIdleConns {
+		maxIdleConns = n
+	}
+	maxOpenConns = n
+}
-- 
GitLab