From 52c050fde3c45fd30339c4f7f1816187978d1a9f Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Fri, 15 Sep 2023 15:39:51 -0500
Subject: [PATCH] do discovery complete

---
 .../modes/digitalocean_discovery/config.go    | 128 +++++++++---------
 1 file changed, 63 insertions(+), 65 deletions(-)

diff --git a/lib/gat/modes/digitalocean_discovery/config.go b/lib/gat/modes/digitalocean_discovery/config.go
index 1550d857..1e53b32d 100644
--- a/lib/gat/modes/digitalocean_discovery/config.go
+++ b/lib/gat/modes/digitalocean_discovery/config.go
@@ -4,6 +4,7 @@ import (
 	"crypto/tls"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"net"
 	"net/http"
 	"net/url"
@@ -11,6 +12,7 @@ import (
 	"time"
 
 	"gfx.cafe/util/go/gun"
+	"github.com/google/uuid"
 	"tuxpa.in/a/zlog/log"
 
 	"pggat/lib/auth/credentials"
@@ -41,8 +43,8 @@ func Load() (Config, error) {
 	return conf, nil
 }
 
-func (T *Config) ListenAndServe() error {
-	dest, err := url.Parse("https://api.digitalocean.com/v2/databases")
+func (T *Config) do(endpoint string, resp any) error {
+	dest, err := url.Parse(endpoint)
 	if err != nil {
 		return err
 	}
@@ -56,13 +58,37 @@ func (T *Config) ListenAndServe() error {
 		},
 	}
 
-	resp, err := http.DefaultClient.Do(&req)
+	res, err := http.DefaultClient.Do(&req)
 	if err != nil {
 		return err
 	}
 
-	var r ListClustersResponse
-	err = json.NewDecoder(resp.Body).Decode(&r)
+	err = json.NewDecoder(res.Body).Decode(resp)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (T *Config) ListClusters() ([]Database, error) {
+	var res ListClustersResponse
+	if err := T.do("https://api.digitalocean.com/v2/databases", &res); err != nil {
+		return nil, err
+	}
+	return res.Databases, nil
+}
+
+func (T *Config) ListReplicas(cluster uuid.UUID) ([]Database, error) {
+	var res ListReplicasResponse
+	if err := T.do(fmt.Sprintf("https://api.digitalocean.com/v2/databases/%s/replicas", cluster.String()), &res); err != nil {
+		return nil, err
+	}
+	return res.Replicas, nil
+}
+
+func (T *Config) ListenAndServe() error {
+	clusters, err := T.ListClusters()
 	if err != nil {
 		return err
 	}
@@ -79,32 +105,12 @@ func (T *Config) ListenAndServe() error {
 		}
 	}()
 
-	for _, cluster := range r.Databases {
+	for _, cluster := range clusters {
 		if cluster.Engine != "pg" {
 			continue
 		}
 
-		replicaDest, err := url.Parse("https://api.digitalocean.com/v2/databases/" + cluster.ID.String() + "/replicas")
-		if err != nil {
-			return err
-		}
-
-		replicaReq := http.Request{
-			Method: http.MethodGet,
-			URL:    replicaDest,
-			Header: http.Header{
-				"Content-Type":  []string{"application/json"},
-				"Authorization": []string{"Bearer " + T.APIKey},
-			},
-		}
-
-		replicaResp, err := http.DefaultClient.Do(&replicaReq)
-		if err != nil {
-			return err
-		}
-
-		var replicaR ListReplicasResponse
-		err = json.NewDecoder(replicaResp.Body).Decode(&replicaR)
+		replicas, err := T.ListReplicas(cluster.ID)
 		if err != nil {
 			return err
 		}
@@ -116,7 +122,7 @@ func (T *Config) ListenAndServe() error {
 			}
 
 			for _, dbname := range cluster.DBNames {
-				p := pool.NewPool(transaction.Apply(pool.Options{
+				poolOptions := transaction.Apply(pool.Options{
 					Credentials:                creds,
 					ServerReconnectInitialTime: 5 * time.Second,
 					ServerReconnectMaxTime:     5 * time.Second,
@@ -128,59 +134,51 @@ func (T *Config) ListenAndServe() error {
 						strutil.MakeCIString("standard_conforming_strings"),
 						strutil.MakeCIString("application_name"),
 					},
-				}))
+				})
+
+				p := pool.NewPool(poolOptions)
+
+				acceptOptions := backends.AcceptOptions{
+					SSLMode: bouncer.SSLModeRequire,
+					SSLConfig: &tls.Config{
+						InsecureSkipVerify: true,
+					},
+					Credentials: creds,
+					Database:    dbname,
+				}
+
 				p.AddRecipe("do", recipe.NewRecipe(recipe.Options{
 					Dialer: dialer.Net{
-						Network: "tcp",
-						Address: net.JoinHostPort(cluster.Connection.Host, strconv.Itoa(cluster.Connection.Port)),
-						AcceptOptions: backends.AcceptOptions{
-							SSLMode: bouncer.SSLModeRequire,
-							SSLConfig: &tls.Config{
-								InsecureSkipVerify: true,
-							},
-							Credentials: creds,
-							Database:    dbname,
-						},
+						Network:       "tcp",
+						Address:       net.JoinHostPort(cluster.Connection.Host, strconv.Itoa(cluster.Connection.Port)),
+						AcceptOptions: acceptOptions,
 					},
 				}))
 
 				pools.Add(user.Name, dbname, p)
+				log.Printf("registered database user=%s database=%s", user.Name, dbname)
 
-				if len(replicaR.Replicas) > 0 {
+				if len(replicas) > 0 {
+					// change pool credentials
 					creds2 := creds
 					creds2.Username = user.Name + "_ro"
-					p2 := pool.NewPool(transaction.Apply(pool.Options{
-						Credentials:                creds2,
-						ServerReconnectInitialTime: 5 * time.Second,
-						ServerReconnectMaxTime:     5 * time.Second,
-						ServerIdleTimeout:          5 * time.Minute,
-						TrackedParameters: []strutil.CIString{
-							strutil.MakeCIString("client_encoding"),
-							strutil.MakeCIString("datestyle"),
-							strutil.MakeCIString("timezone"),
-							strutil.MakeCIString("standard_conforming_strings"),
-							strutil.MakeCIString("application_name"),
-						},
-					}))
-
-					for _, replica := range replicaR.Replicas {
+					poolOptions2 := poolOptions
+					poolOptions2.Credentials = creds2
+
+					p2 := pool.NewPool(poolOptions2)
+
+					for _, replica := range replicas {
 						p2.AddRecipe("do", recipe.NewRecipe(recipe.Options{
 							Dialer: dialer.Net{
-								Network: "tcp",
-								Address: net.JoinHostPort(replica.Connection.Host, strconv.Itoa(replica.Connection.Port)),
-								AcceptOptions: backends.AcceptOptions{
-									SSLMode: bouncer.SSLModeRequire,
-									SSLConfig: &tls.Config{
-										InsecureSkipVerify: true,
-									},
-									Credentials: creds,
-									Database:    dbname,
-								},
+								Network:       "tcp",
+								Address:       net.JoinHostPort(replica.Connection.Host, strconv.Itoa(replica.Connection.Port)),
+								AcceptOptions: acceptOptions,
 							},
 						}))
 					}
 
 					pools.Add(user.Name+"_ro", dbname, p2)
+					log.Printf("registered database user=%s database=%s", user.Name+"_ro", dbname)
 				}
 			}
 		}
-- 
GitLab