From 746b22393937f2d8c468e9ebe3200a996273a210 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Fri, 15 Sep 2023 15:49:13 -0500
Subject: [PATCH] only run sync if necessary, allow using session mode in
 digitalocean discovery mode

---
 .../modes/digitalocean_discovery/config.go    | 14 +++++++++---
 lib/middleware/middlewares/eqp/sync.go        | 22 +++++++++++++++++--
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/lib/gat/modes/digitalocean_discovery/config.go b/lib/gat/modes/digitalocean_discovery/config.go
index 1e53b32d..ecd6e6c9 100644
--- a/lib/gat/modes/digitalocean_discovery/config.go
+++ b/lib/gat/modes/digitalocean_discovery/config.go
@@ -23,6 +23,7 @@ import (
 	"pggat/lib/gat/metrics"
 	"pggat/lib/gat/pool"
 	"pggat/lib/gat/pool/dialer"
+	"pggat/lib/gat/pool/pools/session"
 	"pggat/lib/gat/pool/pools/transaction"
 	"pggat/lib/gat/pool/recipe"
 	"pggat/lib/util/flip"
@@ -30,7 +31,8 @@ import (
 )
 
 type Config struct {
-	APIKey string `env:"PGGAT_DO_API_KEY"`
+	APIKey   string `env:"PGGAT_DO_API_KEY"`
+	PoolMode string `env:"PGGAT_POOL_MODE"`
 }
 
 func Load() (Config, error) {
@@ -122,7 +124,7 @@ func (T *Config) ListenAndServe() error {
 			}
 
 			for _, dbname := range cluster.DBNames {
-				poolOptions := transaction.Apply(pool.Options{
+				poolOptions := pool.Options{
 					Credentials:                creds,
 					ServerReconnectInitialTime: 5 * time.Second,
 					ServerReconnectMaxTime:     5 * time.Second,
@@ -134,7 +136,13 @@ func (T *Config) ListenAndServe() error {
 						strutil.MakeCIString("standard_conforming_strings"),
 						strutil.MakeCIString("application_name"),
 					},
-				})
+				}
+				if T.PoolMode == "session" {
+					poolOptions.ServerResetQuery = "DISCARD ALL"
+					poolOptions = session.Apply(poolOptions)
+				} else {
+					poolOptions = transaction.Apply(poolOptions)
+				}
 
 				p := pool.NewPool(poolOptions)
 
diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go
index 57b81839..818b8599 100644
--- a/lib/middleware/middlewares/eqp/sync.go
+++ b/lib/middleware/middlewares/eqp/sync.go
@@ -7,9 +7,15 @@ import (
 )
 
 func Sync(c *Client, server fed.ReadWriter, s *Server) error {
+	var needsBackendSync bool
+
 	// close all portals on server
 	// we close all because there won't be any for the normal case anyway, and it's hard to tell
 	// if a portal is accurate because the underlying prepared statement could have changed.
+	if len(s.state.portals) > 0 {
+		needsBackendSync = true
+	}
+
 	for name := range s.state.portals {
 		p := packets.Close{
 			Which:  'P',
@@ -41,6 +47,8 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error {
 		if err := server.WritePacket(p.IntoPacket()); err != nil {
 			return err
 		}
+
+		needsBackendSync = true
 	}
 
 	// parse all prepared statements that aren't on server
@@ -55,15 +63,25 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error {
 		if err := server.WritePacket(preparedStatement.Packet); err != nil {
 			return err
 		}
+
+		needsBackendSync = true
 	}
 
 	// bind all portals
+	if len(c.state.portals) > 0 {
+		needsBackendSync = true
+	}
+
 	for _, portal := range c.state.portals {
 		if err := server.WritePacket(portal.Packet); err != nil {
 			return err
 		}
 	}
 
-	_, err := backends.Sync(new(backends.Context), server)
-	return err
+	if needsBackendSync {
+		_, err := backends.Sync(new(backends.Context), server)
+		return err
+	}
+
+	return nil
 }
-- 
GitLab