From 20ab5b9a0323914127049b022561a6a59220e4a5 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 30 Aug 2023 18:35:59 -0500
Subject: [PATCH] fix some auth and connection bugs

---
 lib/auth/credentials/cleartext.go |  7 +------
 lib/bouncer/backends/v0/accept.go |  1 +
 lib/bouncer/sslmode.go            |  2 +-
 lib/gat/modes/pgbouncer/config.go |  2 ++
 lib/gat/modes/pgbouncer/pools.go  | 35 ++++++++++++++++---------------
 lib/gat/modes/zalando/config.go   |  2 +-
 pgbouncer.ini                     |  3 +--
 7 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/lib/auth/credentials/cleartext.go b/lib/auth/credentials/cleartext.go
index 7e51ebdc..fdcb83ff 100644
--- a/lib/auth/credentials/cleartext.go
+++ b/lib/auth/credentials/cleartext.go
@@ -33,8 +33,8 @@ func (T Cleartext) VerifyCleartext(value string) error {
 
 func (T Cleartext) EncodeMD5(salt [4]byte) string {
 	hash := md5.New()
-	hash.Write([]byte(T.Username))
 	hash.Write([]byte(T.Password))
+	hash.Write([]byte(T.Username))
 	sum1 := hash.Sum(nil)
 	hexEncoded := make([]byte, hex.EncodedLen(len(sum1)))
 	hex.Encode(hexEncoded, sum1)
@@ -83,11 +83,6 @@ func MakeCleartextScramEncoder(username, password string, hashGenerator scram.Ha
 }
 
 func (T CleartextScramEncoder) Write(bytes []byte) ([]byte, error) {
-	if bytes == nil {
-		// initial response
-		return nil, nil
-	}
-
 	msg, err := T.conversation.Step(string(bytes))
 	if err != nil {
 		return nil, err
diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go
index 741b66b8..125b3b3b 100644
--- a/lib/bouncer/backends/v0/accept.go
+++ b/lib/bouncer/backends/v0/accept.go
@@ -146,6 +146,7 @@ func startup0(server zap.Conn, creds auth.Credentials) (done bool, err error) {
 				err = ErrBadFormat
 				return
 			}
+
 			c, ok := creds.(auth.MD5)
 			if !ok {
 				return false, auth.ErrMethodNotSupported
diff --git a/lib/bouncer/sslmode.go b/lib/bouncer/sslmode.go
index abcd2be9..cb8b63d5 100644
--- a/lib/bouncer/sslmode.go
+++ b/lib/bouncer/sslmode.go
@@ -22,7 +22,7 @@ func (T SSLMode) ShouldAttempt() bool {
 
 func (T SSLMode) IsRequired() bool {
 	switch T {
-	case SSLModeDisable, SSLModeAllow, SSLModeRequire, "":
+	case SSLModeDisable, SSLModeAllow, SSLModePrefer, "":
 		return false
 	default:
 		return true
diff --git a/lib/gat/modes/pgbouncer/config.go b/lib/gat/modes/pgbouncer/config.go
index 4de99906..548d0115 100644
--- a/lib/gat/modes/pgbouncer/config.go
+++ b/lib/gat/modes/pgbouncer/config.go
@@ -261,6 +261,8 @@ func (T *Config) ListenAndServe() error {
 	allowedStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...)
 
 	acceptOptions := frontends.AcceptOptions{
+		SSLRequired: T.PgBouncer.ClientTLSSSLMode.IsRequired(),
+		// TODO(garet) SSL Certificates
 		AllowedStartupOptions: allowedStartupParameters,
 	}
 
diff --git a/lib/gat/modes/pgbouncer/pools.go b/lib/gat/modes/pgbouncer/pools.go
index 7ec25396..a12e13fb 100644
--- a/lib/gat/modes/pgbouncer/pools.go
+++ b/lib/gat/modes/pgbouncer/pools.go
@@ -1,6 +1,7 @@
 package pgbouncer
 
 import (
+	"crypto/tls"
 	"net"
 	"strconv"
 	"time"
@@ -14,6 +15,7 @@ import (
 	"pggat2/lib/gat/pool/pools/session"
 	"pggat2/lib/gat/pool/pools/transaction"
 	"pggat2/lib/psql"
+	"pggat2/lib/util/maps"
 	"pggat2/lib/util/strutil"
 	"pggat2/lib/zap"
 )
@@ -31,8 +33,8 @@ type poolKey struct {
 type Pools struct {
 	Config *Config
 
-	pools map[poolKey]*pool.Pool
-	keys  map[[8]byte]*pool.Pool
+	pools maps.RWLocked[poolKey, *pool.Pool]
+	keys  maps.RWLocked[[8]byte, *pool.Pool]
 }
 
 func NewPools(config *Config) (*Pools, error) {
@@ -48,7 +50,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool {
 		User:     user,
 		Database: database,
 	}
-	p := T.pools[key]
+	p, _ := T.pools.Load(key)
 	if p != nil {
 		return p
 	}
@@ -110,9 +112,9 @@ func (T *Pools) Lookup(user, database string) *pool.Pool {
 		Password: password, // TODO(garet) md5 and sasl
 	}
 
-	backendDatabase := database
-	if db.DBName != "" {
-		backendDatabase = db.DBName
+	backendDatabase := db.DBName
+	if backendDatabase == "" {
+		backendDatabase = database
 	}
 
 	configUser := T.Config.Users[user]
@@ -152,13 +154,10 @@ func (T *Pools) Lookup(user, database string) *pool.Pool {
 		return nil
 	}
 
-	if T.pools == nil {
-		T.pools = make(map[poolKey]*pool.Pool)
-	}
-	T.pools[poolKey{
+	T.pools.Store(poolKey{
 		User:     user,
 		Database: database,
-	}] = p
+	}, p)
 
 	if db.Host == "" {
 		// connect over unix socket
@@ -182,6 +181,10 @@ func (T *Pools) Lookup(user, database string) *pool.Pool {
 			Network: "tcp",
 			Address: address,
 			AcceptOptions: backends.AcceptOptions{
+				SSLMode: T.Config.PgBouncer.ServerTLSSSLMode,
+				SSLConfig: &tls.Config{
+					InsecureSkipVerify: true, // TODO(garet)
+				},
 				Credentials:       creds,
 				Database:          backendDatabase,
 				StartupParameters: db.StartupParameters,
@@ -210,18 +213,16 @@ func (T *Pools) RegisterKey(key [8]byte, user, database string) {
 	if p == nil {
 		return
 	}
-	if T.keys == nil {
-		T.keys = make(map[[8]byte]*pool.Pool)
-	}
-	T.keys[key] = p
+	T.keys.Store(key, p)
 }
 
 func (T *Pools) UnregisterKey(key [8]byte) {
-	delete(T.keys, key)
+	T.keys.Delete(key)
 }
 
 func (T *Pools) LookupKey(key [8]byte) *pool.Pool {
-	return T.keys[key]
+	p, _ := T.keys.Load(key)
+	return p
 }
 
 var _ gat.Pools = (*Pools)(nil)
diff --git a/lib/gat/modes/zalando/config.go b/lib/gat/modes/zalando/config.go
index f0705c49..24ba25e7 100644
--- a/lib/gat/modes/zalando/config.go
+++ b/lib/gat/modes/zalando/config.go
@@ -57,7 +57,7 @@ func (T *Config) ListenAndServe() error {
 	}
 	pgb.PgBouncer.AdminUsers = []string{T.PGUser}
 	pgb.PgBouncer.AuthQuery = fmt.Sprintf("SELECT * FROM %s.user_lookup($1)", T.PGSchema)
-	pgb.PgBouncer.LogFile = "/var/olg/pgbouncer/pgbouncer.log"
+	pgb.PgBouncer.LogFile = "/var/log/pgbouncer/pgbouncer.log"
 	pgb.PgBouncer.PidFile = "/var/run/pgbouncer/pgbouncer.pid"
 
 	pgb.PgBouncer.ServerTLSSSLMode = bouncer.SSLModeRequire
diff --git a/pgbouncer.ini b/pgbouncer.ini
index 6c175d6f..8a6df6f0 100644
--- a/pgbouncer.ini
+++ b/pgbouncer.ini
@@ -8,5 +8,4 @@ track_extra_parameters = IntervalStyle, session_authorization, default_transacti
 postgres =
 
 [databases]
-regression = host=localhost datestyle=Postgres,MDY timezone=PST8PDT
-postgres = host=localhost
+* = host=localhost datestyle=Postgres,MDY timezone=PST8PDT
-- 
GitLab