From eae87249ad4bf972b4dc5ed0ad7b2d3956a5b7d8 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Mon, 2 Oct 2023 16:48:38 -0500
Subject: [PATCH] almost working

---
 lib/gat/handlers/discovery/config.go |   4 +-
 lib/gat/handlers/discovery/module.go |  18 +-
 lib/gat/handlers/pool/config.go      |   4 +-
 lib/gat/handlers/pool/module.go      |  20 ++-
 test/tester_test.go                  | 244 +++++++++++++++++++--------
 5 files changed, 199 insertions(+), 91 deletions(-)

diff --git a/lib/gat/handlers/discovery/config.go b/lib/gat/handlers/discovery/config.go
index d21a5533..b2712561 100644
--- a/lib/gat/handlers/discovery/config.go
+++ b/lib/gat/handlers/discovery/config.go
@@ -15,8 +15,8 @@ type Config struct {
 
 	Pooler json.RawMessage `json:"pooler" caddy:"namespace=pggat.poolers inline_key=pooler"`
 
-	ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode"`
-	ServerSSL     json.RawMessage `json:"server_ssl" caddy:"namespace=pggat.ssl.clients inline_key=provider"`
+	ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode,omitempty"`
+	ServerSSL     json.RawMessage `json:"server_ssl,omitempty" caddy:"namespace=pggat.ssl.clients inline_key=provider"`
 
 	ServerStartupParameters map[string]string `json:"server_startup_parameters,omitempty"`
 }
diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go
index dfb28631..bd6614d2 100644
--- a/lib/gat/handlers/discovery/module.go
+++ b/lib/gat/handlers/discovery/module.go
@@ -1,6 +1,7 @@
 package discovery
 
 import (
+	"crypto/tls"
 	"fmt"
 	"sync"
 	"time"
@@ -30,8 +31,8 @@ type Module struct {
 
 	discoverer Discoverer
 
-	pooler gat.Pooler
-	ssl    gat.SSLClient
+	pooler    gat.Pooler
+	sslConfig *tls.Config
 
 	serverStartupParameters map[strutil.CIString]string
 
@@ -77,7 +78,8 @@ func (T *Module) Provision(ctx caddy.Context) error {
 		if err != nil {
 			return fmt.Errorf("loading ssl module: %v", err)
 		}
-		T.ssl = val.(gat.SSLClient)
+		ssl := val.(gat.SSLClient)
+		T.sslConfig = ssl.ClientTLSConfig()
 	}
 	T.serverStartupParameters = make(map[strutil.CIString]string, len(T.ServerStartupParameters))
 	for key, value := range T.ServerStartupParameters {
@@ -228,7 +230,7 @@ func (T *Module) replacePrimary(users []User, databases []string, endpoint Endpo
 				Credentials:       primaryCreds,
 				Database:          database,
 				SSLMode:           T.ServerSSLMode,
-				SSLConfig:         T.ssl.ClientTLSConfig(),
+				SSLConfig:         T.sslConfig,
 				StartupParameters: T.serverStartupParameters,
 			}
 
@@ -263,7 +265,7 @@ func (T *Module) addReplicas(replicas map[string]Endpoint, users []User, databas
 					Credentials:       primaryCreds,
 					Database:          database,
 					SSLMode:           T.ServerSSLMode,
-					SSLConfig:         T.ssl.ClientTLSConfig(),
+					SSLConfig:         T.sslConfig,
 					StartupParameters: T.serverStartupParameters,
 				}
 				replicaPool.AddRecipe(id, recipe.NewRecipe(recipe.Config{
@@ -302,7 +304,7 @@ func (T *Module) addReplica(users []User, databases []string, id string, endpoin
 				Credentials:       primaryCreds,
 				Database:          database,
 				SSLMode:           T.ServerSSLMode,
-				SSLConfig:         T.ssl.ClientTLSConfig(),
+				SSLConfig:         T.sslConfig,
 				StartupParameters: T.serverStartupParameters,
 			}
 			p.AddRecipe(id, recipe.NewRecipe(recipe.Config{
@@ -334,7 +336,7 @@ func (T *Module) addUser(primaryEndpoint Endpoint, replicas map[string]Endpoint,
 			Credentials:       primaryCreds,
 			Database:          database,
 			SSLMode:           T.ServerSSLMode,
-			SSLConfig:         T.ssl.ClientTLSConfig(),
+			SSLConfig:         T.sslConfig,
 			StartupParameters: T.serverStartupParameters,
 		}
 
@@ -393,7 +395,7 @@ func (T *Module) addDatabase(primaryEndpoint Endpoint, replicas map[string]Endpo
 			Credentials:       primaryCreds,
 			Database:          database,
 			SSLMode:           T.ServerSSLMode,
-			SSLConfig:         T.ssl.ClientTLSConfig(),
+			SSLConfig:         T.sslConfig,
 			StartupParameters: T.serverStartupParameters,
 		}
 
diff --git a/lib/gat/handlers/pool/config.go b/lib/gat/handlers/pool/config.go
index af6480cd..5a2fc336 100644
--- a/lib/gat/handlers/pool/config.go
+++ b/lib/gat/handlers/pool/config.go
@@ -11,8 +11,8 @@ type Config struct {
 
 	// Server connect options
 	ServerAddress string          `jsonn:"server_address"`
-	ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode"`
-	ServerSSL     json.RawMessage `json:"server_ssl" caddy:"namespace=pggat.ssl.clients inline_key=provider"`
+	ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode,omitempty"`
+	ServerSSL     json.RawMessage `json:"server_ssl,omitempty" caddy:"namespace=pggat.ssl.clients inline_key=provider"`
 
 	// Server routing options
 	ServerUsername          string            `json:"server_username"`
diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go
index d6ac91ef..61b0085a 100644
--- a/lib/gat/handlers/pool/module.go
+++ b/lib/gat/handlers/pool/module.go
@@ -1,12 +1,14 @@
 package pool_handler
 
 import (
+	"crypto/tls"
 	"fmt"
 	"strings"
 
 	"github.com/caddyserver/caddy/v2"
 
 	"gfx.cafe/gfx/pggat/lib/auth/credentials"
+	"gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0"
 	"gfx.cafe/gfx/pggat/lib/fed"
 	"gfx.cafe/gfx/pggat/lib/gat"
 	"gfx.cafe/gfx/pggat/lib/gat/metrics"
@@ -40,11 +42,15 @@ func (T *Module) Provision(ctx caddy.Context) error {
 	}
 	pooler := val.(gat.Pooler)
 
-	val, err = ctx.LoadModule(T, "ServerSSL")
-	if err != nil {
-		return fmt.Errorf("loading ssl module: %v", err)
+	var sslConfig *tls.Config
+	if T.ServerSSL != nil {
+		val, err = ctx.LoadModule(T, "ServerSSL")
+		if err != nil {
+			return fmt.Errorf("loading ssl module: %v", err)
+		}
+		ssl := val.(gat.SSLClient)
+		sslConfig = ssl.ClientTLSConfig()
 	}
-	ssl := val.(gat.SSLClient)
 
 	creds := credentials.FromString(T.ServerUsername, T.ServerPassword)
 	startupParameters := make(map[strutil.CIString]string, len(T.ServerStartupParameters))
@@ -63,7 +69,7 @@ func (T *Module) Provision(ctx caddy.Context) error {
 		Network:           network,
 		Address:           T.ServerAddress,
 		SSLMode:           T.ServerSSLMode,
-		SSLConfig:         ssl.ClientTLSConfig(),
+		SSLConfig:         sslConfig,
 		Username:          T.ServerUsername,
 		Credentials:       creds,
 		Database:          T.ServerDatabase,
@@ -84,6 +90,10 @@ func (T *Module) Cleanup() error {
 }
 
 func (T *Module) Handle(conn *fed.Conn) error {
+	if err := frontends.Authenticate(conn, nil); err != nil {
+		return err
+	}
+
 	return T.pool.Serve(conn)
 }
 
diff --git a/test/tester_test.go b/test/tester_test.go
index 21ff7c70..6ef9b431 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -1,20 +1,23 @@
 package test_test
 
 import (
-	"context"
 	"crypto/rand"
-	"encoding/hex"
+	"encoding/base64"
 	"fmt"
-	"net"
 	_ "net/http/pprof"
 	"strconv"
+	"strings"
 	"testing"
 
 	"github.com/caddyserver/caddy/v2"
+	"github.com/caddyserver/caddy/v2/caddyconfig"
 
-	"gfx.cafe/gfx/pggat/lib/auth"
 	"gfx.cafe/gfx/pggat/lib/auth/credentials"
 	"gfx.cafe/gfx/pggat/lib/gat"
+	"gfx.cafe/gfx/pggat/lib/gat/gatcaddyfile"
+	pool_handler "gfx.cafe/gfx/pggat/lib/gat/handlers/pool"
+	"gfx.cafe/gfx/pggat/lib/gat/handlers/rewrite_password"
+	"gfx.cafe/gfx/pggat/lib/gat/matchers"
 	"gfx.cafe/gfx/pggat/lib/gat/pool"
 	"gfx.cafe/gfx/pggat/lib/gat/pool/recipe"
 	"gfx.cafe/gfx/pggat/lib/gat/poolers/session"
@@ -23,21 +26,120 @@ import (
 	"gfx.cafe/gfx/pggat/test/tests"
 )
 
-func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Dialer, error) {
-	for i := 0; i < n; i++ {
-		var server gat.ServerConfig
+type dialer struct {
+	Address  string
+	Username string
+	Password string
+	Database string
+}
 
-		l, err := caddy.NetworkAddress{
-			Network: "tcp",
-		}.Listen(context.Background(), 0, net.ListenConfig{})
-		if err != nil {
-			return recipe.Dialer{}, nil
+var nextPort int
+
+func randAddress() string {
+	nextPort++
+	return "/tmp/.s.PGGAT." + strconv.Itoa(nextPort)
+}
+
+func resolveNetwork(address string) string {
+	if strings.HasPrefix(address, "/") {
+		return "unix"
+	} else {
+		return "tcp"
+	}
+}
+
+func randPassword() (string, error) {
+	var b [20]byte
+	_, err := rand.Read(b[:])
+	if err != nil {
+		return "", err
+	}
+
+	return base64.StdEncoding.EncodeToString(b[:]), nil
+}
+
+func createServer(parent dialer, poolers map[string]caddy.Module) (server gat.ServerConfig, dialers map[string]dialer, err error) {
+	address := randAddress()
+
+	server.Listen = []gat.ListenerConfig{
+		{
+			Address: address,
+		},
+	}
+
+	var password string
+	password, err = randPassword()
+	if err != nil {
+		return
+	}
+
+	server.Routes = append(
+		server.Routes,
+		gat.RouteConfig{
+			Handle: gatcaddyfile.JSONModuleObject(
+				&rewrite_password.Module{
+					Password: password,
+				},
+				gatcaddyfile.Handler,
+				"handler",
+				nil,
+			),
+		},
+	)
+
+	for name, pooler := range poolers {
+		p := pool_handler.Module{
+			Config: pool_handler.Config{
+				Pooler: gatcaddyfile.JSONModuleObject(
+					pooler,
+					gatcaddyfile.Pooler,
+					"pooler",
+					nil,
+				),
+
+				ServerAddress: parent.Address,
+
+				ServerUsername: parent.Username,
+				ServerPassword: parent.Password,
+				ServerDatabase: parent.Database,
+			},
 		}
-		ls := l.(net.Listener)
-		port := ls.Addr().(*net.TCPAddr).Port
 
+		server.Routes = append(server.Routes, gat.RouteConfig{
+			Match: gatcaddyfile.JSONModuleObject(
+				&matchers.Database{
+					Database: name,
+				},
+				gatcaddyfile.Matcher,
+				"matcher",
+				nil,
+			),
+			Handle: gatcaddyfile.JSONModuleObject(
+				&p,
+				gatcaddyfile.Handler,
+				"handler",
+				nil,
+			),
+		})
+
+		if dialers == nil {
+			dialers = make(map[string]dialer)
+		}
+		dialers[name] = dialer{
+			Address:  address,
+			Username: "pooler",
+			Password: password,
+			Database: name,
+		}
+	}
+
+	return
+}
+
+func daisyChain(config *gat.Config, control dialer, n int) (dialer, error) {
+	for i := 0; i < n; i++ {
 		poolConfig := pool.ManagementConfig{}
-		var pooler gat.Pooler
+		var pooler caddy.Module
 		if i%2 == 0 {
 			pooler = &transaction.Module{
 				ManagementConfig: poolConfig,
@@ -49,15 +151,16 @@ func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Di
 			}
 		}
 
-		// TODO(garet) add handler to server that uses pooler and control to connect
+		server, dialers, err := createServer(control, map[string]caddy.Module{
+			"pool": pooler,
+		})
 
-		control = recipe.Dialer{
-			Network:     "tcp",
-			Address:     ":" + strconv.Itoa(port),
-			Username:    "runner",
-			Credentials: creds,
-			Database:    "pool",
+		if err != nil {
+			return dialer{}, err
 		}
+
+		control = dialers["pool"]
+		config.Servers = append(config.Servers, server)
 	}
 
 	return control, nil
@@ -75,76 +178,69 @@ func TestTester(t *testing.T) {
 		Database: "postgres",
 	}
 
-	// generate random password for testing
-	var raw [32]byte
-	_, err := rand.Read(raw[:])
+	config := gat.Config{}
+
+	parent, err := daisyChain(&config, dialer{
+		Address:  "localhost:5432",
+		Username: "postgres",
+		Password: "password",
+		Database: "postgres",
+	}, 16)
 	if err != nil {
 		t.Error(err)
 		return
 	}
-	password := hex.EncodeToString(raw[:])
-	creds := credentials.Cleartext{
-		Username: "runner",
-		Password: password,
-	}
 
-	parent, err := daisyChain(creds, control, 16)
+	server, dialers, err := createServer(parent, map[string]caddy.Module{
+		"transaction": &transaction.Module{},
+		"session": &session.Module{
+			ManagementConfig: pool.ManagementConfig{
+				ServerResetQuery: "discard all",
+			},
+		},
+	})
 	if err != nil {
 		t.Error(err)
 		return
 	}
 
-	m := new(raw_pools.Module)
-	transactionPool := pool.NewPool(transaction.Apply(pool.Config{
-		Credentials: creds,
-	}))
-	transactionPool.AddRecipe("runner", recipe.NewRecipe(recipe.Config{
-		Dialer: parent,
-	}))
-	m.Add("runner", "transaction", transactionPool)
-
-	sessionPool := pool.NewPool(session.Apply(pool.Config{
-		Credentials:      creds,
-		ServerResetQuery: "discard all",
-	}))
-	sessionPool.AddRecipe("runner", recipe.NewRecipe(recipe.Config{
-		Dialer: parent,
-	}))
-	m.Add("runner", "session", sessionPool)
-
-	l := &net_listener.Module{
-		Config: net_listener.Config{
-			Network: "tcp",
-			Address: ":0",
-		},
+	config.Servers = append(config.Servers, server)
+
+	transactionDialer := recipe.Dialer{
+		Network:  resolveNetwork(dialers["transaction"].Address),
+		Address:  dialers["transaction"].Address,
+		Username: dialers["transaction"].Username,
+		Credentials: credentials.FromString(
+			dialers["transaction"].Username,
+			dialers["transaction"].Password,
+		),
+		Database: "transaction",
 	}
-	if err = l.Start(); err != nil {
-		t.Error(err)
-		return
+	sessionDialer := recipe.Dialer{
+		Network:  resolveNetwork(dialers["transaction"].Address),
+		Address:  dialers["session"].Address,
+		Username: dialers["session"].Username,
+		Credentials: credentials.FromString(
+			dialers["session"].Username,
+			dialers["session"].Password,
+		),
+		Database: "session",
 	}
-	port := l.Addr().(*net.TCPAddr).Port
 
-	server := gat.NewServer(m, l)
+	caddyConfig := caddy.Config{
+		AppsRaw: caddy.ModuleMap{
+			"pggat": caddyconfig.JSON(config, nil),
+		},
+	}
 
-	if err = server.Start(); err != nil {
+	if err = caddy.Run(&caddyConfig); err != nil {
 		t.Error(err)
 		return
 	}
 
-	transactionDialer := recipe.Dialer{
-		Network:     "tcp",
-		Address:     ":" + strconv.Itoa(port),
-		Username:    "runner",
-		Credentials: creds,
-		Database:    "transaction",
-	}
-	sessionDialer := recipe.Dialer{
-		Network:     "tcp",
-		Address:     ":" + strconv.Itoa(port),
-		Username:    "runner",
-		Credentials: creds,
-		Database:    "session",
-	}
+	defer func() {
+		_ = caddy.Stop()
+	}()
 
 	tester := test.NewTester(test.Config{
 		Stress: 8,
-- 
GitLab