From a10e92b19a1ee5989dd9da6c88b1127b1ab8fc8e Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 29 Aug 2023 16:33:58 -0500
Subject: [PATCH] session pool working fine

---
 cmd/cgat/main.go                  | 79 ++++++++-----------------------
 lib/gat/acceptor.go               | 48 +++++++++++++++++--
 lib/gat/gat.go                    | 51 --------------------
 lib/gat/modes/pgbouncer/config.go | 38 ++++++---------
 lib/gat/modes/zalando/config.go   | 14 ++----
 lib/gat/pool/pool.go              | 60 ++++++++++++-----------
 lib/gat/pool/recipe.go            |  2 +
 lib/gat/pools.go                  | 75 +++++++++++++++++++++++++++++
 8 files changed, 193 insertions(+), 174 deletions(-)
 delete mode 100644 lib/gat/gat.go
 create mode 100644 lib/gat/pools.go

diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go
index 42981009..e3235210 100644
--- a/cmd/cgat/main.go
+++ b/cmd/cgat/main.go
@@ -1,19 +1,14 @@
 package main
 
 import (
-	"crypto/tls"
 	"net/http"
 	_ "net/http/pprof"
+	"os"
 
 	"tuxpa.in/a/zlog/log"
 
-	"pggat2/lib/auth/credentials"
-	"pggat2/lib/bouncer"
-	"pggat2/lib/bouncer/backends/v0"
-	"pggat2/lib/bouncer/frontends/v0"
-	"pggat2/lib/gat"
-	"pggat2/lib/gat/pool"
-	"pggat2/lib/gat/pool/pools/session"
+	"pggat2/lib/gat/modes/pgbouncer"
+	"pggat2/lib/gat/modes/zalando"
 )
 
 func main() {
@@ -23,56 +18,9 @@ func main() {
 
 	log.Printf("Starting pggat...")
 
-	g := new(gat.Gat)
-	g.TestPool = session.NewPool(pool.Options{
-		Credentials: credentials.Cleartext{
-			Username: "postgres",
-			Password: "password",
-		},
-	})
-	g.TestPool.AddRecipe("test", pool.Recipe{
-		Dialer: pool.NetDialer{
-			Network: "tcp",
-			Address: "localhost:5432",
-
-			AcceptOptions: backends.AcceptOptions{
-				SSLMode: bouncer.SSLModeAllow,
-				SSLConfig: &tls.Config{
-					InsecureSkipVerify: true,
-				},
-				Credentials: credentials.Cleartext{
-					Username: "postgres",
-					Password: "password",
-				},
-				Database: "postgres",
-			},
-		},
-		MinConnections: 1,
-		MaxConnections: 1,
-	})
-	err := gat.ListenAndServe("tcp", ":6432", frontends.AcceptOptions{}, g)
-	if err != nil {
-		panic(err)
-	}
-
-	/*
-		if len(os.Args) == 2 {
-			log.Printf("running in pgbouncer compatibility mode")
-			conf, err := pgbouncer.Load(os.Args[1])
-			if err != nil {
-				panic(err)
-			}
-
-			err = conf.ListenAndServe()
-			if err != nil {
-				panic(err)
-			}
-			return
-		}
-
-		log.Printf("running in zalando compatibility mode")
-
-		conf, err := zalando.Load()
+	if len(os.Args) == 2 {
+		log.Printf("running in pgbouncer compatibility mode")
+		conf, err := pgbouncer.Load(os.Args[1])
 		if err != nil {
 			panic(err)
 		}
@@ -81,5 +29,18 @@ func main() {
 		if err != nil {
 			panic(err)
 		}
-	*/
+		return
+	}
+
+	log.Printf("running in zalando compatibility mode")
+
+	conf, err := zalando.Load()
+	if err != nil {
+		panic(err)
+	}
+
+	err = conf.ListenAndServe()
+	if err != nil {
+		panic(err)
+	}
 }
diff --git a/lib/gat/acceptor.go b/lib/gat/acceptor.go
index 7dfa49b3..53c718e0 100644
--- a/lib/gat/acceptor.go
+++ b/lib/gat/acceptor.go
@@ -3,6 +3,7 @@ package gat
 import (
 	"net"
 
+	"pggat2/lib/auth"
 	"pggat2/lib/bouncer/frontends/v0"
 	"pggat2/lib/zap"
 )
@@ -37,22 +38,59 @@ func Listen(network, address string, options frontends.AcceptOptions) (Acceptor,
 	}, nil
 }
 
-func Serve(acceptor Acceptor, gat *Gat) error {
+func serve(client zap.Conn, acceptParams frontends.AcceptParams, pools Pools) error {
+	defer func() {
+		_ = client.Close()
+	}()
+
+	if acceptParams.CancelKey != [8]byte{} {
+		p := pools.LookupKey(acceptParams.CancelKey)
+		if p == nil {
+			return nil
+		}
+		return p.Cancel(acceptParams.CancelKey)
+	}
+
+	p := pools.Lookup(acceptParams.User, acceptParams.Database)
+
+	var credentials auth.Credentials
+	if p != nil {
+		credentials = p.GetCredentials()
+	}
+
+	authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{
+		Credentials: credentials,
+	})
+	if err != nil {
+		return err
+	}
+
+	if p == nil {
+		return nil
+	}
+
+	pools.RegisterKey(authParams.BackendKey, acceptParams.User, acceptParams.Database)
+	defer pools.UnregisterKey(authParams.BackendKey)
+
+	return p.Serve(client, acceptParams, authParams)
+}
+
+func Serve(acceptor Acceptor, pools Pools) error {
 	for {
-		conn, params, err := acceptor.Accept()
+		conn, acceptParams, err := acceptor.Accept()
 		if err != nil {
 			continue
 		}
 		go func() {
-			_ = gat.Serve(conn, params)
+			_ = serve(conn, acceptParams, pools)
 		}()
 	}
 }
 
-func ListenAndServe(network, address string, options frontends.AcceptOptions, gat *Gat) error {
+func ListenAndServe(network, address string, options frontends.AcceptOptions, pools Pools) error {
 	listener, err := Listen(network, address, options)
 	if err != nil {
 		return err
 	}
-	return Serve(listener, gat)
+	return Serve(listener, pools)
 }
diff --git a/lib/gat/gat.go b/lib/gat/gat.go
deleted file mode 100644
index 8cadbb01..00000000
--- a/lib/gat/gat.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package gat
-
-import (
-	"pggat2/lib/auth"
-	"pggat2/lib/bouncer/frontends/v0"
-	"pggat2/lib/gat/pool"
-	"pggat2/lib/zap"
-)
-
-type Gat struct {
-	TestPool *pool.Pool
-}
-
-func (T *Gat) Serve(client zap.Conn, acceptParams frontends.AcceptParams) error {
-	defer func() {
-		_ = client.Close()
-	}()
-
-	if acceptParams.CancelKey != [8]byte{} {
-		// TODO(garet) execute cancel
-		return nil
-	}
-
-	p, err := T.GetPool(acceptParams.User, acceptParams.Database)
-	if err != nil {
-		return err
-	}
-
-	var credentials auth.Credentials
-	if p != nil {
-		credentials = p.GetCredentials()
-	}
-
-	authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{
-		Credentials: credentials,
-	})
-	if err != nil {
-		return err
-	}
-
-	if p == nil {
-		return nil
-	}
-
-	return p.Serve(client, acceptParams, authParams)
-}
-
-func (T *Gat) GetPool(user, database string) (*pool.Pool, error) {
-	return T.TestPool, nil
-	return nil, nil // TODO(garet)
-}
diff --git a/lib/gat/modes/pgbouncer/config.go b/lib/gat/modes/pgbouncer/config.go
index aad103ff..dfd7eca9 100644
--- a/lib/gat/modes/pgbouncer/config.go
+++ b/lib/gat/modes/pgbouncer/config.go
@@ -10,6 +10,7 @@ import (
 
 	"tuxpa.in/a/zlog/log"
 
+	"pggat2/lib/bouncer"
 	"pggat2/lib/bouncer/backends/v0"
 	"pggat2/lib/bouncer/frontends/v0"
 	"pggat2/lib/gat/pool"
@@ -44,17 +45,6 @@ const (
 	AuthTypePam         AuthType = "pam"
 )
 
-type SSLMode string
-
-const (
-	SSLModeDisable    SSLMode = "disable"
-	SSLModeAllow      SSLMode = "allow"
-	SSLModePrefer     SSLMode = "prefer"
-	SSLModeRequire    SSLMode = "require"
-	SSLModeVerifyCa   SSLMode = "verify-ca"
-	SSLModeVerifyFull SSLMode = "verify-full"
-)
-
 type TLSProtocol string
 
 const (
@@ -130,7 +120,7 @@ type PgBouncer struct {
 	DnsNxdomainTtl          float64            `ini:"dns_nxdomain_ttl"`
 	DnsZoneCheckPeriod      float64            `ini:"dns_zone_check_period"`
 	ResolvConf              string             `ini:"resolv.conf"`
-	ClientTLSSSLMode        SSLMode            `ini:"client_tls_sslmode"`
+	ClientTLSSSLMode        bouncer.SSLMode    `ini:"client_tls_sslmode"`
 	ClientTLSKeyFile        string             `ini:"client_tls_key_file"`
 	ClientTLSCertFile       string             `ini:"client_tls_cert_file"`
 	ClientTLSCaFile         string             `ini:"client_tls_ca_file"`
@@ -138,7 +128,7 @@ type PgBouncer struct {
 	ClientTLSCiphers        []TLSCipher        `ini:"client_tls_ciphers"`
 	ClientTLSECDHCurve      TLSECDHCurve       `ini:"client_tls_ecdhcurve"`
 	ClientTLSDHEParams      TLSDHEParams       `ini:"client_tls_dheparams"`
-	ServerTLSSSLMode        SSLMode            `ini:"server_tls_sslmode"`
+	ServerTLSSSLMode        bouncer.SSLMode    `ini:"server_tls_sslmode"`
 	ServerTLSCaFile         string             `ini:"server_tls_ca_file"`
 	ServerTLSKeyFile        string             `ini:"server_tls_key_file"`
 	ServerTLSCertFile       string             `ini:"server_tls_cert_file"`
@@ -229,7 +219,7 @@ var Default = Config{
 		AutodbIdleTimeout:    3600.0,
 		DnsMaxTtl:            15.0,
 		DnsNxdomainTtl:       15.0,
-		ClientTLSSSLMode:     SSLModeDisable,
+		ClientTLSSSLMode:     bouncer.SSLModeDisable,
 		ClientTLSProtocols: []TLSProtocol{
 			TLSProtocolSecure,
 		},
@@ -237,7 +227,7 @@ var Default = Config{
 			"fast",
 		},
 		ClientTLSECDHCurve: "auto",
-		ServerTLSSSLMode:   SSLModePrefer,
+		ServerTLSSSLMode:   bouncer.SSLModePrefer,
 		ServerTLSProtocols: []TLSProtocol{
 			TLSProtocolSecure,
 		},
@@ -282,7 +272,7 @@ func (T *Config) ListenAndServe() error {
 		AllowedStartupOptions: allowedStartupParameters,
 	}
 
-	g := new(gat.Gat)
+	pools := new(gat.PoolsMap)
 
 	var authFile map[string]string
 	if T.PgBouncer.AuthFile != "" {
@@ -302,10 +292,6 @@ func (T *Config) ListenAndServe() error {
 			Username: name,
 			Password: authFile[name], // TODO(garet) md5 and sasl
 		}
-		/* TODO(garet)
-		u := gat.NewUser(creds)
-		g.AddUser(u)
-		*/
 
 		for dbname, db := range T.Databases {
 			// filter out dbs specific to users
@@ -329,7 +315,9 @@ func (T *Config) ListenAndServe() error {
 			}
 
 			poolOptions := pool.Options{
+				Credentials:       creds,
 				TrackedParameters: trackedParameters,
+				ServerResetQuery:  T.PgBouncer.ServerResetQuery,
 				ServerIdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)),
 			}
 
@@ -338,12 +326,16 @@ func (T *Config) ListenAndServe() error {
 			case PoolModeSession:
 				p = session.NewPool(poolOptions)
 			case PoolModeTransaction:
+				if T.PgBouncer.ServerResetQueryAlways == 0 {
+					poolOptions.ServerResetQuery = ""
+				}
+				panic("transaction mode not implemented yet")
 				// TODO(garet)
 			default:
 				return errors.New("unsupported pool mode")
 			}
 
-			// TODO(garet) add to gat
+			pools.Add(name, dbname, p)
 
 			if db.Host == "" {
 				// connect over unix socket
@@ -402,7 +394,7 @@ func (T *Config) ListenAndServe() error {
 
 			log.Printf("listening on %s", listen)
 
-			return gat.ListenAndServe("tcp", listen, acceptOptions, g)
+			return gat.ListenAndServe("tcp", listen, acceptOptions, pools)
 		})
 	}
 
@@ -418,7 +410,7 @@ func (T *Config) ListenAndServe() error {
 
 		log.Printf("listening on unix:%s", dir)
 
-		return gat.ListenAndServe("unix", dir, acceptOptions, g)
+		return gat.ListenAndServe("unix", dir, acceptOptions, pools)
 	})
 
 	return bank.Wait()
diff --git a/lib/gat/modes/zalando/config.go b/lib/gat/modes/zalando/config.go
index 806d43ec..1a2a2739 100644
--- a/lib/gat/modes/zalando/config.go
+++ b/lib/gat/modes/zalando/config.go
@@ -46,26 +46,22 @@ func Load() (Config, error) {
 }
 
 func (T *Config) ListenAndServe() error {
-	g := new(gat.Gat)
+	pools := new(gat.PoolsMap)
 
 	creds := credentials.Cleartext{
 		Username: T.PGUser,
 		Password: T.PGPassword,
 	}
 
-	/* TODO(garet)
-	user := gat.NewUser(creds)
-	g.AddUser(user)
-	*/
-
 	var p *pool.Pool
 	if T.PoolerMode == "transaction" {
-		// p = transaction.NewPool(pool.Options{})
+		panic("transaction mode not implemented yet")
+		// TODO(garet) p = transaction.NewPool(pool.Options{})
 	} else {
 		p = session.NewPool(pool.Options{})
 	}
 
-	// TODO(garet) add to gat
+	pools.Add(T.PGUser, "test", p)
 
 	p.AddRecipe("zalando", pool.Recipe{
 		Dialer: pool.NetDialer{
@@ -87,7 +83,7 @@ func (T *Config) ListenAndServe() error {
 
 		log.Printf("listening on %s", listen)
 
-		return gat.ListenAndServe("tcp", listen, frontends.AcceptOptions{}, g)
+		return gat.ListenAndServe("tcp", listen, frontends.AcceptOptions{}, pools)
 	})
 
 	return bank.Wait()
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index 0654b39f..a7a2313b 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -39,17 +39,24 @@ type poolRecipe struct {
 type Pool struct {
 	options Options
 
-	maxServers int
-	recipes    map[string]*poolRecipe
-	servers    map[uuid.UUID]poolServer
-	clients    map[uuid.UUID]zap.Conn
-	mu         sync.Mutex
+	recipes map[string]*poolRecipe
+	servers map[uuid.UUID]poolServer
+	clients map[uuid.UUID]zap.Conn
+	mu      sync.Mutex
 }
 
 func NewPool(options Options) *Pool {
-	return &Pool{
+	p := &Pool{
 		options: options,
 	}
+
+	if options.ServerIdleTimeout != 0 {
+		go func() {
+			// TODO(garet) check pool for idle servers
+		}()
+	}
+
+	return p
 }
 
 func (T *Pool) GetCredentials() auth.Credentials {
@@ -62,6 +69,7 @@ func (T *Pool) _scaleUpRecipe(name string) {
 	server, params, err := r.recipe.Dialer.Dial()
 	if err != nil {
 		log.Printf("failed to dial server: %v", err)
+		return
 	}
 
 	serverID := uuid.New()
@@ -103,7 +111,6 @@ func (T *Pool) AddRecipe(name string, recipe Recipe) {
 	if T.recipes == nil {
 		T.recipes = make(map[string]*poolRecipe)
 	}
-	T.maxServers += recipe.MaxConnections
 	T.recipes[name] = &poolRecipe{
 		recipe: recipe,
 		count:  0,
@@ -118,9 +125,6 @@ func (T *Pool) RemoveRecipe(name string) {
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
-	if r, ok := T.recipes[name]; ok {
-		T.maxServers -= r.count
-	}
 	delete(T.recipes, name)
 
 	// close all servers with this recipe
@@ -138,11 +142,13 @@ func (T *Pool) scaleUp() {
 	defer T.mu.Unlock()
 
 	for name, r := range T.recipes {
-		if r.count < r.recipe.MaxConnections {
+		if r.recipe.MaxConnections == 0 || r.count < r.recipe.MaxConnections {
 			T._scaleUpRecipe(name)
 			return
 		}
 	}
+
+	log.Println("warning: tried to scale up pool but no space was available")
 }
 
 func (T *Pool) syncInitialParameters(
@@ -185,21 +191,16 @@ func (T *Pool) syncInitialParameters(
 			continue
 		}
 
-		if slices.Contains(T.options.TrackedParameters, key) {
-			serverErr = backends.ResetParameter(new(backends.Context), server, key)
-			if serverErr != nil {
-				return
-			}
-		} else {
-			// send to client
-			p := packets.ParameterStatus{
-				Key:   key.String(),
-				Value: value,
-			}
-			clientErr = client.WritePacket(p.IntoPacket())
-			if clientErr != nil {
-				return
-			}
+		// Don't need to run reset on server because it will reset it to the initial value
+
+		// send to client
+		p := packets.ParameterStatus{
+			Key:   key.String(),
+			Value: value,
+		}
+		clientErr = client.WritePacket(p.IntoPacket())
+		if clientErr != nil {
+			return
 		}
 	}
 
@@ -270,7 +271,7 @@ func (T *Pool) Serve(
 				server.eqpServer.SetClient(eqpClient)
 			}
 		}
-		if clientErr != nil && serverErr != nil {
+		if clientErr == nil && serverErr == nil {
 			clientErr, serverErr = bouncers.Bounce(client, server.conn, packet)
 		}
 		if serverErr != nil {
@@ -352,3 +353,8 @@ func (T *Pool) removeServer(serverID uuid.UUID) {
 
 	T._removeServer(serverID)
 }
+
+func (T *Pool) Cancel(key [8]byte) error {
+	// TODO(garet) implement cancel
+	return nil
+}
diff --git a/lib/gat/pool/recipe.go b/lib/gat/pool/recipe.go
index 66260a8f..e12c8a24 100644
--- a/lib/gat/pool/recipe.go
+++ b/lib/gat/pool/recipe.go
@@ -3,5 +3,7 @@ package pool
 type Recipe struct {
 	Dialer         Dialer
 	MinConnections int
+	// MaxConnections is the max number of active server connections for this recipe.
+	// 0 = unlimited
 	MaxConnections int
 }
diff --git a/lib/gat/pools.go b/lib/gat/pools.go
new file mode 100644
index 00000000..6cb9e834
--- /dev/null
+++ b/lib/gat/pools.go
@@ -0,0 +1,75 @@
+package gat
+
+import (
+	"pggat2/lib/gat/pool"
+	"pggat2/lib/util/maps"
+)
+
+type Pools interface {
+	Lookup(user, database string) *pool.Pool
+
+	// Key based lookup functions (for cancellation)
+
+	RegisterKey(key [8]byte, user, database string)
+	UnregisterKey(key [8]byte)
+
+	LookupKey(key [8]byte) *pool.Pool
+}
+
+type mapKey struct {
+	User     string
+	Database string
+}
+
+type PoolsMap struct {
+	pools maps.RWLocked[mapKey, *pool.Pool]
+	keys  maps.RWLocked[[8]byte, mapKey]
+}
+
+func (T *PoolsMap) Add(user, database string, pool *pool.Pool) {
+	T.pools.Store(mapKey{
+		User:     user,
+		Database: database,
+	}, pool)
+}
+
+func (T *PoolsMap) Remove(user, database string) {
+	T.pools.Delete(mapKey{
+		User:     user,
+		Database: database,
+	})
+}
+
+func (T *PoolsMap) Lookup(user, database string) *pool.Pool {
+	p, _ := T.pools.Load(mapKey{
+		User:     user,
+		Database: database,
+	})
+	return p
+}
+
+// key based lookup funcs
+
+func (T *PoolsMap) RegisterKey(key [8]byte, user, database string) {
+	T.keys.Store(key, mapKey{
+		User:     user,
+		Database: database,
+	})
+}
+
+func (T *PoolsMap) UnregisterKey(key [8]byte) {
+	T.keys.Delete(key)
+}
+
+func (T *PoolsMap) LookupKey(key [8]byte) *pool.Pool {
+	m, ok := T.keys.Load(key)
+	if !ok {
+		return nil
+	}
+	p, ok := T.pools.Load(m)
+	if !ok {
+		T.keys.Delete(key)
+		return nil
+	}
+	return p
+}
-- 
GitLab