diff --git a/lib/gat/modes/digitalocean_discovery/config.go b/lib/gat/modes/digitalocean_discovery/config.go index 1e53b32d4438e1235c7963f7284afc0077c248c6..ecd6e6c90d064f0d51d79027765328e60a732976 100644 --- a/lib/gat/modes/digitalocean_discovery/config.go +++ b/lib/gat/modes/digitalocean_discovery/config.go @@ -23,6 +23,7 @@ import ( "pggat/lib/gat/metrics" "pggat/lib/gat/pool" "pggat/lib/gat/pool/dialer" + "pggat/lib/gat/pool/pools/session" "pggat/lib/gat/pool/pools/transaction" "pggat/lib/gat/pool/recipe" "pggat/lib/util/flip" @@ -30,7 +31,8 @@ import ( ) type Config struct { - APIKey string `env:"PGGAT_DO_API_KEY"` + APIKey string `env:"PGGAT_DO_API_KEY"` + PoolMode string `env:"PGGAT_POOL_MODE"` } func Load() (Config, error) { @@ -122,7 +124,7 @@ func (T *Config) ListenAndServe() error { } for _, dbname := range cluster.DBNames { - poolOptions := transaction.Apply(pool.Options{ + poolOptions := pool.Options{ Credentials: creds, ServerReconnectInitialTime: 5 * time.Second, ServerReconnectMaxTime: 5 * time.Second, @@ -134,7 +136,13 @@ func (T *Config) ListenAndServe() error { strutil.MakeCIString("standard_conforming_strings"), strutil.MakeCIString("application_name"), }, - }) + } + if T.PoolMode == "session" { + poolOptions.ServerResetQuery = "DISCARD ALL" + poolOptions = session.Apply(poolOptions) + } else { + poolOptions = transaction.Apply(poolOptions) + } p := pool.NewPool(poolOptions) diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go index 57b818396db2faa79e0b92cfbaf180d5c9b3e9a6..818b859979c3dedfcfe83e58975dbf33c06c3384 100644 --- a/lib/middleware/middlewares/eqp/sync.go +++ b/lib/middleware/middlewares/eqp/sync.go @@ -7,9 +7,15 @@ import ( ) func Sync(c *Client, server fed.ReadWriter, s *Server) error { + var needsBackendSync bool + // close all portals on server // we close all because there won't be any for the normal case anyway, and it's hard to tell // if a portal is accurate because the underlying prepared statement could have changed. + if len(s.state.portals) > 0 { + needsBackendSync = true + } + for name := range s.state.portals { p := packets.Close{ Which: 'P', @@ -41,6 +47,8 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { if err := server.WritePacket(p.IntoPacket()); err != nil { return err } + + needsBackendSync = true } // parse all prepared statements that aren't on server @@ -55,15 +63,25 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { if err := server.WritePacket(preparedStatement.Packet); err != nil { return err } + + needsBackendSync = true } // bind all portals + if len(c.state.portals) > 0 { + needsBackendSync = true + } + for _, portal := range c.state.portals { if err := server.WritePacket(portal.Packet); err != nil { return err } } - _, err := backends.Sync(new(backends.Context), server) - return err + if needsBackendSync { + _, err := backends.Sync(new(backends.Context), server) + return err + } + + return nil }