From eae87249ad4bf972b4dc5ed0ad7b2d3956a5b7d8 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Mon, 2 Oct 2023 16:48:38 -0500 Subject: [PATCH] almost working --- lib/gat/handlers/discovery/config.go | 4 +- lib/gat/handlers/discovery/module.go | 18 +- lib/gat/handlers/pool/config.go | 4 +- lib/gat/handlers/pool/module.go | 20 ++- test/tester_test.go | 244 +++++++++++++++++++-------- 5 files changed, 199 insertions(+), 91 deletions(-) diff --git a/lib/gat/handlers/discovery/config.go b/lib/gat/handlers/discovery/config.go index d21a5533..b2712561 100644 --- a/lib/gat/handlers/discovery/config.go +++ b/lib/gat/handlers/discovery/config.go @@ -15,8 +15,8 @@ type Config struct { Pooler json.RawMessage `json:"pooler" caddy:"namespace=pggat.poolers inline_key=pooler"` - ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode"` - ServerSSL json.RawMessage `json:"server_ssl" caddy:"namespace=pggat.ssl.clients inline_key=provider"` + ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode,omitempty"` + ServerSSL json.RawMessage `json:"server_ssl,omitempty" caddy:"namespace=pggat.ssl.clients inline_key=provider"` ServerStartupParameters map[string]string `json:"server_startup_parameters,omitempty"` } diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go index dfb28631..bd6614d2 100644 --- a/lib/gat/handlers/discovery/module.go +++ b/lib/gat/handlers/discovery/module.go @@ -1,6 +1,7 @@ package discovery import ( + "crypto/tls" "fmt" "sync" "time" @@ -30,8 +31,8 @@ type Module struct { discoverer Discoverer - pooler gat.Pooler - ssl gat.SSLClient + pooler gat.Pooler + sslConfig *tls.Config serverStartupParameters map[strutil.CIString]string @@ -77,7 +78,8 @@ func (T *Module) Provision(ctx caddy.Context) error { if err != nil { return fmt.Errorf("loading ssl module: %v", err) } - T.ssl = val.(gat.SSLClient) + ssl := val.(gat.SSLClient) + T.sslConfig = ssl.ClientTLSConfig() } T.serverStartupParameters = make(map[strutil.CIString]string, len(T.ServerStartupParameters)) for key, value := range T.ServerStartupParameters { @@ -228,7 +230,7 @@ func (T *Module) replacePrimary(users []User, databases []string, endpoint Endpo Credentials: primaryCreds, Database: database, SSLMode: T.ServerSSLMode, - SSLConfig: T.ssl.ClientTLSConfig(), + SSLConfig: T.sslConfig, StartupParameters: T.serverStartupParameters, } @@ -263,7 +265,7 @@ func (T *Module) addReplicas(replicas map[string]Endpoint, users []User, databas Credentials: primaryCreds, Database: database, SSLMode: T.ServerSSLMode, - SSLConfig: T.ssl.ClientTLSConfig(), + SSLConfig: T.sslConfig, StartupParameters: T.serverStartupParameters, } replicaPool.AddRecipe(id, recipe.NewRecipe(recipe.Config{ @@ -302,7 +304,7 @@ func (T *Module) addReplica(users []User, databases []string, id string, endpoin Credentials: primaryCreds, Database: database, SSLMode: T.ServerSSLMode, - SSLConfig: T.ssl.ClientTLSConfig(), + SSLConfig: T.sslConfig, StartupParameters: T.serverStartupParameters, } p.AddRecipe(id, recipe.NewRecipe(recipe.Config{ @@ -334,7 +336,7 @@ func (T *Module) addUser(primaryEndpoint Endpoint, replicas map[string]Endpoint, Credentials: primaryCreds, Database: database, SSLMode: T.ServerSSLMode, - SSLConfig: T.ssl.ClientTLSConfig(), + SSLConfig: T.sslConfig, StartupParameters: T.serverStartupParameters, } @@ -393,7 +395,7 @@ func (T *Module) addDatabase(primaryEndpoint Endpoint, replicas map[string]Endpo Credentials: primaryCreds, Database: database, SSLMode: T.ServerSSLMode, - SSLConfig: T.ssl.ClientTLSConfig(), + SSLConfig: T.sslConfig, StartupParameters: T.serverStartupParameters, } diff --git a/lib/gat/handlers/pool/config.go b/lib/gat/handlers/pool/config.go index af6480cd..5a2fc336 100644 --- a/lib/gat/handlers/pool/config.go +++ b/lib/gat/handlers/pool/config.go @@ -11,8 +11,8 @@ type Config struct { // Server connect options ServerAddress string `jsonn:"server_address"` - ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode"` - ServerSSL json.RawMessage `json:"server_ssl" caddy:"namespace=pggat.ssl.clients inline_key=provider"` + ServerSSLMode bouncer.SSLMode `json:"server_ssl_mode,omitempty"` + ServerSSL json.RawMessage `json:"server_ssl,omitempty" caddy:"namespace=pggat.ssl.clients inline_key=provider"` // Server routing options ServerUsername string `json:"server_username"` diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index d6ac91ef..61b0085a 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -1,12 +1,14 @@ package pool_handler import ( + "crypto/tls" "fmt" "strings" "github.com/caddyserver/caddy/v2" "gfx.cafe/gfx/pggat/lib/auth/credentials" + "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/metrics" @@ -40,11 +42,15 @@ func (T *Module) Provision(ctx caddy.Context) error { } pooler := val.(gat.Pooler) - val, err = ctx.LoadModule(T, "ServerSSL") - if err != nil { - return fmt.Errorf("loading ssl module: %v", err) + var sslConfig *tls.Config + if T.ServerSSL != nil { + val, err = ctx.LoadModule(T, "ServerSSL") + if err != nil { + return fmt.Errorf("loading ssl module: %v", err) + } + ssl := val.(gat.SSLClient) + sslConfig = ssl.ClientTLSConfig() } - ssl := val.(gat.SSLClient) creds := credentials.FromString(T.ServerUsername, T.ServerPassword) startupParameters := make(map[strutil.CIString]string, len(T.ServerStartupParameters)) @@ -63,7 +69,7 @@ func (T *Module) Provision(ctx caddy.Context) error { Network: network, Address: T.ServerAddress, SSLMode: T.ServerSSLMode, - SSLConfig: ssl.ClientTLSConfig(), + SSLConfig: sslConfig, Username: T.ServerUsername, Credentials: creds, Database: T.ServerDatabase, @@ -84,6 +90,10 @@ func (T *Module) Cleanup() error { } func (T *Module) Handle(conn *fed.Conn) error { + if err := frontends.Authenticate(conn, nil); err != nil { + return err + } + return T.pool.Serve(conn) } diff --git a/test/tester_test.go b/test/tester_test.go index 21ff7c70..6ef9b431 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -1,20 +1,23 @@ package test_test import ( - "context" "crypto/rand" - "encoding/hex" + "encoding/base64" "fmt" - "net" _ "net/http/pprof" "strconv" + "strings" "testing" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig" - "gfx.cafe/gfx/pggat/lib/auth" "gfx.cafe/gfx/pggat/lib/auth/credentials" "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/gatcaddyfile" + pool_handler "gfx.cafe/gfx/pggat/lib/gat/handlers/pool" + "gfx.cafe/gfx/pggat/lib/gat/handlers/rewrite_password" + "gfx.cafe/gfx/pggat/lib/gat/matchers" "gfx.cafe/gfx/pggat/lib/gat/pool" "gfx.cafe/gfx/pggat/lib/gat/pool/recipe" "gfx.cafe/gfx/pggat/lib/gat/poolers/session" @@ -23,21 +26,120 @@ import ( "gfx.cafe/gfx/pggat/test/tests" ) -func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Dialer, error) { - for i := 0; i < n; i++ { - var server gat.ServerConfig +type dialer struct { + Address string + Username string + Password string + Database string +} - l, err := caddy.NetworkAddress{ - Network: "tcp", - }.Listen(context.Background(), 0, net.ListenConfig{}) - if err != nil { - return recipe.Dialer{}, nil +var nextPort int + +func randAddress() string { + nextPort++ + return "/tmp/.s.PGGAT." + strconv.Itoa(nextPort) +} + +func resolveNetwork(address string) string { + if strings.HasPrefix(address, "/") { + return "unix" + } else { + return "tcp" + } +} + +func randPassword() (string, error) { + var b [20]byte + _, err := rand.Read(b[:]) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(b[:]), nil +} + +func createServer(parent dialer, poolers map[string]caddy.Module) (server gat.ServerConfig, dialers map[string]dialer, err error) { + address := randAddress() + + server.Listen = []gat.ListenerConfig{ + { + Address: address, + }, + } + + var password string + password, err = randPassword() + if err != nil { + return + } + + server.Routes = append( + server.Routes, + gat.RouteConfig{ + Handle: gatcaddyfile.JSONModuleObject( + &rewrite_password.Module{ + Password: password, + }, + gatcaddyfile.Handler, + "handler", + nil, + ), + }, + ) + + for name, pooler := range poolers { + p := pool_handler.Module{ + Config: pool_handler.Config{ + Pooler: gatcaddyfile.JSONModuleObject( + pooler, + gatcaddyfile.Pooler, + "pooler", + nil, + ), + + ServerAddress: parent.Address, + + ServerUsername: parent.Username, + ServerPassword: parent.Password, + ServerDatabase: parent.Database, + }, } - ls := l.(net.Listener) - port := ls.Addr().(*net.TCPAddr).Port + server.Routes = append(server.Routes, gat.RouteConfig{ + Match: gatcaddyfile.JSONModuleObject( + &matchers.Database{ + Database: name, + }, + gatcaddyfile.Matcher, + "matcher", + nil, + ), + Handle: gatcaddyfile.JSONModuleObject( + &p, + gatcaddyfile.Handler, + "handler", + nil, + ), + }) + + if dialers == nil { + dialers = make(map[string]dialer) + } + dialers[name] = dialer{ + Address: address, + Username: "pooler", + Password: password, + Database: name, + } + } + + return +} + +func daisyChain(config *gat.Config, control dialer, n int) (dialer, error) { + for i := 0; i < n; i++ { poolConfig := pool.ManagementConfig{} - var pooler gat.Pooler + var pooler caddy.Module if i%2 == 0 { pooler = &transaction.Module{ ManagementConfig: poolConfig, @@ -49,15 +151,16 @@ func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Di } } - // TODO(garet) add handler to server that uses pooler and control to connect + server, dialers, err := createServer(control, map[string]caddy.Module{ + "pool": pooler, + }) - control = recipe.Dialer{ - Network: "tcp", - Address: ":" + strconv.Itoa(port), - Username: "runner", - Credentials: creds, - Database: "pool", + if err != nil { + return dialer{}, err } + + control = dialers["pool"] + config.Servers = append(config.Servers, server) } return control, nil @@ -75,76 +178,69 @@ func TestTester(t *testing.T) { Database: "postgres", } - // generate random password for testing - var raw [32]byte - _, err := rand.Read(raw[:]) + config := gat.Config{} + + parent, err := daisyChain(&config, dialer{ + Address: "localhost:5432", + Username: "postgres", + Password: "password", + Database: "postgres", + }, 16) if err != nil { t.Error(err) return } - password := hex.EncodeToString(raw[:]) - creds := credentials.Cleartext{ - Username: "runner", - Password: password, - } - parent, err := daisyChain(creds, control, 16) + server, dialers, err := createServer(parent, map[string]caddy.Module{ + "transaction": &transaction.Module{}, + "session": &session.Module{ + ManagementConfig: pool.ManagementConfig{ + ServerResetQuery: "discard all", + }, + }, + }) if err != nil { t.Error(err) return } - m := new(raw_pools.Module) - transactionPool := pool.NewPool(transaction.Apply(pool.Config{ - Credentials: creds, - })) - transactionPool.AddRecipe("runner", recipe.NewRecipe(recipe.Config{ - Dialer: parent, - })) - m.Add("runner", "transaction", transactionPool) - - sessionPool := pool.NewPool(session.Apply(pool.Config{ - Credentials: creds, - ServerResetQuery: "discard all", - })) - sessionPool.AddRecipe("runner", recipe.NewRecipe(recipe.Config{ - Dialer: parent, - })) - m.Add("runner", "session", sessionPool) - - l := &net_listener.Module{ - Config: net_listener.Config{ - Network: "tcp", - Address: ":0", - }, + config.Servers = append(config.Servers, server) + + transactionDialer := recipe.Dialer{ + Network: resolveNetwork(dialers["transaction"].Address), + Address: dialers["transaction"].Address, + Username: dialers["transaction"].Username, + Credentials: credentials.FromString( + dialers["transaction"].Username, + dialers["transaction"].Password, + ), + Database: "transaction", } - if err = l.Start(); err != nil { - t.Error(err) - return + sessionDialer := recipe.Dialer{ + Network: resolveNetwork(dialers["transaction"].Address), + Address: dialers["session"].Address, + Username: dialers["session"].Username, + Credentials: credentials.FromString( + dialers["session"].Username, + dialers["session"].Password, + ), + Database: "session", } - port := l.Addr().(*net.TCPAddr).Port - server := gat.NewServer(m, l) + caddyConfig := caddy.Config{ + AppsRaw: caddy.ModuleMap{ + "pggat": caddyconfig.JSON(config, nil), + }, + } - if err = server.Start(); err != nil { + if err = caddy.Run(&caddyConfig); err != nil { t.Error(err) return } - transactionDialer := recipe.Dialer{ - Network: "tcp", - Address: ":" + strconv.Itoa(port), - Username: "runner", - Credentials: creds, - Database: "transaction", - } - sessionDialer := recipe.Dialer{ - Network: "tcp", - Address: ":" + strconv.Itoa(port), - Username: "runner", - Credentials: creds, - Database: "session", - } + defer func() { + _ = caddy.Stop() + }() tester := test.NewTester(test.Config{ Stress: 8, -- GitLab