From c8b5ee6590f93458f8778054995120e762fc0faf Mon Sep 17 00:00:00 2001
From: Garet Halliday <ghalliday@gfxlabs.io>
Date: Mon, 17 Oct 2022 16:07:02 -0500
Subject: [PATCH] close connection and queue for recreation if conn is
 unrecoverable

---
 lib/gat/gatling/server/server.go        | 23 ++++-----
 lib/gat/interfaces.go                   |  1 +
 lib/gat/pool/session/pool.go            |  4 ++
 lib/gat/pool/transaction/shard/shard.go | 62 +++++++++++++++++++------
 test/docker-compose.yml                 |  1 +
 5 files changed, 62 insertions(+), 29 deletions(-)

diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go
index bd4720b4..f346e0fe 100644
--- a/lib/gat/gatling/server/server.go
+++ b/lib/gat/gatling/server/server.go
@@ -360,18 +360,8 @@ func (s *Server) stabilize() {
 		if err != nil {
 			return
 		}
-		err = s.flush()
-		if err != nil {
-			return
-		}
-	}
-	query := new(protocol.Query)
-	query.Fields.Query = "end"
-	err := s.writePacket(query)
-	if err != nil {
-		return
 	}
-	err = s.flush()
+	err := s.flush()
 	if err != nil {
 		return
 	}
@@ -558,6 +548,8 @@ func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error {
 		if err != nil {
 			return err
 		}
+	default:
+		return fmt.Errorf("don't know how to handle %T", packet)
 	}
 	return nil
 }
@@ -704,9 +696,12 @@ func (s *Server) CallFunction(ctx context.Context, client gat.Client, payload *p
 	return s.sendAndLink(ctx, client, payload)
 }
 
-func (s *Server) Close(ctx context.Context) error {
-	<-ctx.Done()
-	return nil
+func (s *Server) Close() error {
+	err := s.writePacket(&protocol.Close{})
+	if err != nil {
+		return err
+	}
+	return s.conn.Close()
 }
 
 var _ gat.Connection = (*Server)(nil)
diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go
index 0a02ebed..73223061 100644
--- a/lib/gat/interfaces.go
+++ b/lib/gat/interfaces.go
@@ -128,6 +128,7 @@ type Connection interface {
 
 	// IsCloseNeeded returns whether this connection failed a health check
 	IsCloseNeeded() bool
+	Close() error
 
 	// actions
 	Describe(ctx context.Context, client Client, payload *protocol.Describe) error
diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go
index b5fcc587..c4b0c0b5 100644
--- a/lib/gat/pool/session/pool.go
+++ b/lib/gat/pool/session/pool.go
@@ -78,6 +78,10 @@ func (p *Pool) OnDisconnect(client gat.Client) {
 	if !ok {
 		return
 	}
+	if c.IsCloseNeeded() {
+		_ = c.Close()
+		return
+	}
 	p.servers <- c
 }
 
diff --git a/lib/gat/pool/transaction/shard/shard.go b/lib/gat/pool/transaction/shard/shard.go
index fd74e73f..e9803700 100644
--- a/lib/gat/pool/transaction/shard/shard.go
+++ b/lib/gat/pool/transaction/shard/shard.go
@@ -5,13 +5,42 @@ import (
 	"gfx.cafe/gfx/pggat/lib/config"
 	"gfx.cafe/gfx/pggat/lib/gat"
 	"gfx.cafe/gfx/pggat/lib/gat/protocol"
+	"git.tuxpa.in/a/zlog/log"
 	"math/rand"
 	"reflect"
 )
 
+type shardConn struct {
+	conn gat.Connection
+	conf *config.Server
+	s    *Shard
+}
+
+func (s *shardConn) connect() {
+	if s.s == nil || s.conf == nil {
+		return
+	}
+	if s.conn != nil {
+		_ = s.conn.Close()
+	}
+	var err error
+	s.conn, err = s.s.dialer(context.TODO(), s.s.options, s.s.user, s.s.conf, s.conf)
+	if err != nil {
+		log.Println("error connecting to server:", err)
+	}
+	return
+}
+
+func (s *shardConn) acquire() gat.Connection {
+	if s.conn == nil || s.conn.IsCloseNeeded() {
+		s.connect()
+	}
+	return s.conn
+}
+
 type Shard struct {
-	primary  gat.Connection
-	replicas []gat.Connection
+	primary  shardConn
+	replicas []shardConn
 
 	pool *config.Pool
 	user *config.User
@@ -36,19 +65,22 @@ func FromConfig(dialer gat.Dialer, options []protocol.FieldsStartupMessageParame
 	return out
 }
 
+func (s *Shard) newConn(conf *config.Server) shardConn {
+	return shardConn{
+		conf: conf,
+		s:    s,
+	}
+}
+
 func (s *Shard) init() {
-	s.primary = nil
+	s.primary = shardConn{}
 	s.replicas = nil
 	for _, serv := range s.conf.Servers {
-		srv, err := s.dialer(context.TODO(), s.options, s.user, s.conf, serv)
-		if err != nil {
-			continue
-		}
 		switch serv.Role {
 		case config.SERVERROLE_PRIMARY:
-			s.primary = srv
+			s.primary = s.newConn(serv)
 		default:
-			s.replicas = append(s.replicas, srv)
+			s.replicas = append(s.replicas, s.newConn(serv))
 		}
 	}
 }
@@ -56,29 +88,29 @@ func (s *Shard) init() {
 func (s *Shard) Choose(role config.ServerRole) gat.Connection {
 	switch role {
 	case config.SERVERROLE_PRIMARY:
-		return s.primary
+		return s.primary.acquire()
 	case config.SERVERROLE_REPLICA:
 		if len(s.replicas) == 0 {
 			// only return primary if primary reads are enabled
 			if s.pool.PrimaryReadsEnabled {
-				return s.primary
+				return s.primary.acquire()
 			}
 			return nil
 		}
 
 		// read from a random replica
-		return s.replicas[rand.Intn(len(s.replicas))]
+		return s.replicas[rand.Intn(len(s.replicas))].acquire()
 	default:
 		return nil
 	}
 }
 
 func (s *Shard) GetPrimary() gat.Connection {
-	return s.primary
+	return s.Choose(config.SERVERROLE_PRIMARY)
 }
 
-func (s *Shard) GetReplicas() []gat.Connection {
-	return s.replicas
+func (s *Shard) GetReplica() gat.Connection {
+	return s.Choose(config.SERVERROLE_REPLICA)
 }
 
 func (s *Shard) EnsureConfig(c *config.Shard) {
diff --git a/test/docker-compose.yml b/test/docker-compose.yml
index a49d935c..0641271d 100644
--- a/test/docker-compose.yml
+++ b/test/docker-compose.yml
@@ -20,6 +20,7 @@ services:
       PSQL_DB_USER_RO: postgres
       PSQL_DB_PASS_RO: example
       PSQL_PRI_DB_HOST: db
+      PSQL_REP_DB_HOST: db
     ports:
       - 6432:6432
       - 6060:6060
-- 
GitLab