From 0e8f59a5aed85cdb356d24b84c249afa17b9243b Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Mon, 25 Sep 2023 16:08:09 -0500
Subject: [PATCH] cloud sql replicas and passwords

---
 cmd/cgat/main.go                              |  21 +++
 lib/gat/modules/cloud_sql_discovery/config.go |   2 +-
 .../modules/cloud_sql_discovery/discoverer.go | 135 +++++++++++++++---
 .../digitalocean_discovery/discoverer.go      |   4 +-
 lib/gat/modules/discovery/module.go           |   5 +-
 lib/gat/modules/ssl_endpoint/module.go        |  40 ++++++
 lib/gat/pool/pool.go                          |   2 +-
 7 files changed, 186 insertions(+), 23 deletions(-)
 create mode 100644 lib/gat/modules/ssl_endpoint/module.go

diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go
index aed9645e..876a8aef 100644
--- a/cmd/cgat/main.go
+++ b/cmd/cgat/main.go
@@ -5,13 +5,16 @@ import (
 	"net/http"
 	_ "net/http/pprof"
 	"os"
+	"time"
 
 	"tuxpa.in/a/zlog/log"
 
 	"pggat/lib/gat"
+	"pggat/lib/gat/metrics"
 	"pggat/lib/gat/modules/cloud_sql_discovery"
 	"pggat/lib/gat/modules/digitalocean_discovery"
 	"pggat/lib/gat/modules/pgbouncer"
+	"pggat/lib/gat/modules/ssl_endpoint"
 	"pggat/lib/gat/modules/zalando"
 	"pggat/lib/gat/modules/zalando_operator_discovery"
 )
@@ -73,12 +76,30 @@ func main() {
 
 	var server gat.Server
 
+	// load and add main module
 	module, err := loadModule(runMode)
 	if err != nil {
 		panic(err)
 	}
 	server.AddModule(module)
 
+	// back up ssl endpoint (for modules that don't have endpoints by default such as discovery)
+	ep, err := ssl_endpoint.NewModule()
+	if err != nil {
+		panic(err)
+	}
+	server.AddModule(ep)
+
+	go func() {
+		var m metrics.Server
+		for {
+			time.Sleep(1 * time.Minute)
+			server.ReadMetrics(&m)
+			log.Printf("%s", m.String())
+			m.Clear()
+		}
+	}()
+
 	err = server.ListenAndServe()
 	if err != nil {
 		panic(err)
diff --git a/lib/gat/modules/cloud_sql_discovery/config.go b/lib/gat/modules/cloud_sql_discovery/config.go
index 5ff96306..c6f84166 100644
--- a/lib/gat/modules/cloud_sql_discovery/config.go
+++ b/lib/gat/modules/cloud_sql_discovery/config.go
@@ -17,7 +17,7 @@ func Load() (Config, error) {
 	var conf Config
 	gun.Load(&conf)
 	if conf.Project == "" {
-		return Config{}, errors.New("expected google cloud project id")
+		return Config{}, errors.New("expected PGGAT_GC_PROJECT")
 	}
 	return conf, nil
 }
diff --git a/lib/gat/modules/cloud_sql_discovery/discoverer.go b/lib/gat/modules/cloud_sql_discovery/discoverer.go
index 7658aa41..848ad4d6 100644
--- a/lib/gat/modules/cloud_sql_discovery/discoverer.go
+++ b/lib/gat/modules/cloud_sql_discovery/discoverer.go
@@ -2,14 +2,26 @@ package cloud_sql_discovery
 
 import (
 	"context"
+	"crypto/tls"
 	"net"
 	"strings"
 
 	sqladmin "google.golang.org/api/sqladmin/v1beta4"
 
+	"pggat/lib/auth/credentials"
+	"pggat/lib/bouncer"
+	"pggat/lib/bouncer/backends/v0"
+	"pggat/lib/bouncer/bouncers/v2"
+	"pggat/lib/fed"
 	"pggat/lib/gat/modules/discovery"
+	"pggat/lib/gsql"
 )
 
+type authQueryResult struct {
+	Username string `sql:"0"`
+	Password string `sql:"1"`
+}
+
 type Discoverer struct {
 	config Config
 
@@ -28,24 +40,59 @@ func NewDiscoverer(config Config) (*Discoverer, error) {
 	}, nil
 }
 
-func (T *Discoverer) instanceToCluster(instance *sqladmin.DatabaseInstance) (discovery.Cluster, error) {
-	var address string
-	for _, ip := range instance.IpAddresses {
+func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, replicas ...*sqladmin.DatabaseInstance) (discovery.Cluster, error) {
+	var primaryAddress string
+	for _, ip := range primary.IpAddresses {
 		if ip.Type != T.config.IpAddressType {
 			continue
 		}
-		address = net.JoinHostPort(ip.IpAddress, "5432")
+		primaryAddress = net.JoinHostPort(ip.IpAddress, "5432")
 	}
 
 	c := discovery.Cluster{
-		ID: instance.Name,
+		ID: primary.Name,
 		Primary: discovery.Endpoint{
 			Network: "tcp",
-			Address: address,
+			Address: primaryAddress,
 		},
+		Replicas: make(map[string]discovery.Endpoint, len(replicas)),
+	}
+
+	for _, replica := range replicas {
+		var replicaAddress string
+		for _, ip := range primary.IpAddresses {
+			if ip.Type != T.config.IpAddressType {
+				continue
+			}
+			replicaAddress = net.JoinHostPort(ip.IpAddress, "5432")
+		}
+		c.Replicas[replica.Name] = discovery.Endpoint{
+			Network: "tcp",
+			Address: replicaAddress,
+		}
+	}
+
+	databases, err := T.google.Databases.List(T.config.Project, primary.Name).Do()
+	if err != nil {
+		return discovery.Cluster{}, err
+	}
+	c.Databases = make([]string, 0, len(databases.Items))
+	for _, database := range databases.Items {
+		c.Databases = append(c.Databases, database.Name)
+	}
+
+	if len(c.Databases) == 0 {
+		return c, nil
 	}
 
-	users, err := T.google.Users.List(T.config.Project, instance.Name).Do()
+	var admin fed.Conn
+	defer func() {
+		if admin != nil {
+			_ = admin.Close()
+		}
+	}()
+
+	users, err := T.google.Users.List(T.config.Project, primary.Name).Do()
 	if err != nil {
 		return discovery.Cluster{}, err
 	}
@@ -55,7 +102,54 @@ func (T *Discoverer) instanceToCluster(instance *sqladmin.DatabaseInstance) (dis
 		if user.Name == T.config.AuthUser {
 			password = T.config.AuthPassword
 		} else {
-			// TODO(garet) lookup password
+			// dial admin connection
+			if admin == nil {
+				raw, err := net.Dial("tcp", primaryAddress)
+				if err != nil {
+					return discovery.Cluster{}, err
+				}
+				admin = fed.WrapNetConn(raw)
+				_, err = backends.Accept(&backends.AcceptContext{
+					Conn: admin,
+					Options: backends.AcceptOptions{
+						SSLMode: bouncer.SSLModePrefer,
+						SSLConfig: &tls.Config{
+							InsecureSkipVerify: true,
+						},
+						Username:    T.config.AuthUser,
+						Credentials: credentials.FromString(T.config.AuthUser, T.config.AuthPassword),
+						Database:    c.Databases[0],
+					},
+				})
+				if err != nil {
+					return discovery.Cluster{}, err
+				}
+			}
+
+			var result authQueryResult
+			client := new(gsql.Client)
+			err := gsql.ExtendedQuery(client, &result, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", user.Name)
+			if err != nil {
+				return discovery.Cluster{}, err
+			}
+			err = client.Close()
+			if err != nil {
+				return discovery.Cluster{}, err
+			}
+
+			initialPacket, err := client.ReadPacket(true, nil)
+			if err != nil {
+				return discovery.Cluster{}, err
+			}
+			_, err, err2 := bouncers.Bounce(client, admin, initialPacket)
+			if err != nil {
+				return discovery.Cluster{}, err
+			}
+			if err2 != nil {
+				return discovery.Cluster{}, err2
+			}
+
+			password = result.Password
 		}
 
 		c.Users = append(c.Users, discovery.User{
@@ -64,15 +158,6 @@ func (T *Discoverer) instanceToCluster(instance *sqladmin.DatabaseInstance) (dis
 		})
 	}
 
-	databases, err := T.google.Databases.List(T.config.Project, instance.Name).Do()
-	if err != nil {
-		return discovery.Cluster{}, err
-	}
-	c.Databases = make([]string, 0, len(databases.Items))
-	for _, database := range databases.Items {
-		c.Databases = append(c.Databases, database.Name)
-	}
-
 	return c, nil
 }
 
@@ -84,11 +169,25 @@ func (T *Discoverer) Clusters() ([]discovery.Cluster, error) {
 
 	res := make([]discovery.Cluster, 0, len(clusters.Items))
 	for _, cluster := range clusters.Items {
+		if cluster.InstanceType != "CLOUD_SQL_INSTANCE" {
+			continue
+		}
+
 		if !strings.HasPrefix(cluster.DatabaseVersion, "POSTGRES_") {
 			continue
 		}
 
-		c, err := T.instanceToCluster(cluster)
+		replicas := make([]*sqladmin.DatabaseInstance, 0, len(cluster.ReplicaNames))
+		for _, replicaName := range cluster.ReplicaNames {
+			for _, replica := range clusters.Items {
+				if replica.Name == replicaName {
+					replicas = append(replicas, replica)
+					break
+				}
+			}
+		}
+
+		c, err := T.instanceToCluster(cluster, replicas...)
 		if err != nil {
 			return nil, err
 		}
diff --git a/lib/gat/modules/digitalocean_discovery/discoverer.go b/lib/gat/modules/digitalocean_discovery/discoverer.go
index a334f4d5..eab2cf7d 100644
--- a/lib/gat/modules/digitalocean_discovery/discoverer.go
+++ b/lib/gat/modules/digitalocean_discovery/discoverer.go
@@ -52,9 +52,9 @@ func (T Discoverer) Clusters() ([]discovery.Cluster, error) {
 			Users:     make([]discovery.User, 0, len(cluster.Users)),
 		}
 
-		for _, user := range c.Users {
+		for _, user := range cluster.Users {
 			c.Users = append(c.Users, discovery.User{
-				Username: user.Username,
+				Username: user.Name,
 				Password: user.Password,
 			})
 		}
diff --git a/lib/gat/modules/discovery/module.go b/lib/gat/modules/discovery/module.go
index 2da07997..a0e0ea34 100644
--- a/lib/gat/modules/discovery/module.go
+++ b/lib/gat/modules/discovery/module.go
@@ -179,7 +179,10 @@ func (T *Module) discoverLoop() {
 		case next := <-T.config.Discoverer.Updated():
 			T.updated(T.clusters[next.ID], next)
 		case <-reconcile:
-			_ = T.reconcile() // TODO(garet) do something with this error
+			err := T.reconcile()
+			if err != nil {
+				log.Printf("failed to reconcile: %v", err)
+			}
 		}
 	}
 }
diff --git a/lib/gat/modules/ssl_endpoint/module.go b/lib/gat/modules/ssl_endpoint/module.go
new file mode 100644
index 00000000..71f259dc
--- /dev/null
+++ b/lib/gat/modules/ssl_endpoint/module.go
@@ -0,0 +1,40 @@
+package ssl_endpoint
+
+import (
+	"pggat/lib/gat"
+	"pggat/lib/util/strutil"
+)
+
+type Module struct{}
+
+func NewModule() (*Module, error) {
+	return &Module{}, nil
+}
+
+func (T *Module) GatModule() {}
+
+func (T *Module) Endpoints() []gat.Endpoint {
+	// TODO(garet) gen ssl keys
+
+	return []gat.Endpoint{
+		{
+			Network: "tcp",
+			Address: ":5432",
+			AcceptOptions: gat.FrontendAcceptOptions{
+				SSLRequired: false,
+				AllowedStartupOptions: []strutil.CIString{
+					strutil.MakeCIString("client_encoding"),
+					strutil.MakeCIString("datestyle"),
+					strutil.MakeCIString("timezone"),
+					strutil.MakeCIString("standard_conforming_strings"),
+					strutil.MakeCIString("application_name"),
+					strutil.MakeCIString("extra_float_digits"),
+					strutil.MakeCIString("options"),
+				},
+			},
+		},
+	}
+}
+
+var _ gat.Module = (*Module)(nil)
+var _ gat.Listener = (*Module)(nil)
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index 73a01a69..445da1db 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -489,7 +489,7 @@ func (T *Pool) Close() {
 
 	// remove clients
 	for _, client := range T.clients {
-		T.removeClient(client)
+		T.removeClientL1(client)
 	}
 
 	// remove recipes
-- 
GitLab