From 85cea0ef8f5c53a802a0cb19810bea53f49252e4 Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Mon, 17 Jun 2024 00:05:57 -0500
Subject: [PATCH] noot

---
 lib/fed/listeners/netconnlistener/listener.go | 28 +++++++++++-
 lib/gat/listen.go                             | 44 ++++++++++++++-----
 lib/gat/standard/standard.go                  |  4 ++
 test/tester_test.go                           |  9 +++-
 4 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/lib/fed/listeners/netconnlistener/listener.go b/lib/fed/listeners/netconnlistener/listener.go
index e2156e40..00a74e5f 100644
--- a/lib/fed/listeners/netconnlistener/listener.go
+++ b/lib/fed/listeners/netconnlistener/listener.go
@@ -1,16 +1,42 @@
 package netconnlistener
 
 import (
+	"context"
+	"crypto/tls"
+	"log"
 	"net"
+	"os"
+	"path/filepath"
 
 	"gfx.cafe/gfx/pggat/lib/fed"
 	"gfx.cafe/gfx/pggat/lib/fed/codecs/netconncodec"
+	"gfx.cafe/gfx/pggat/lib/gat"
+	"github.com/caddyserver/caddy/v2"
 )
 
 type Listener struct {
 	Listener net.Listener
 }
 
+func init() {
+	gat.RegisterNetwork("default", ListenerFunc)
+}
+
+func ListenerFunc(ctx context.Context, addr caddy.NetworkAddress, config *tls.Config) (fed.Listener, error) {
+	if addr.Network == "unix" {
+		if err := os.MkdirAll(filepath.Dir(addr.Host), 0o660); err != nil {
+			return nil, err
+		}
+	}
+	listener, err := addr.Listen(context.Background(), 0, net.ListenConfig{})
+	if err != nil {
+		return nil, err
+	}
+	log.Println("got fed conn")
+	ncn := &Listener{Listener: listener.(net.Listener)}
+	return ncn, nil
+}
+
 func (listener *Listener) Accept(fn func(*fed.Conn)) error {
 	raw, err := listener.Listener.Accept()
 	if err != nil {
@@ -23,5 +49,5 @@ func (listener *Listener) Accept(fn func(*fed.Conn)) error {
 	return nil
 }
 func (l *Listener) Close() error {
-	return l.Close()
+	return l.Listener.Close()
 }
diff --git a/lib/gat/listen.go b/lib/gat/listen.go
index dd9dbdbc..134f0617 100644
--- a/lib/gat/listen.go
+++ b/lib/gat/listen.go
@@ -2,11 +2,9 @@ package gat
 
 import (
 	"context"
+	"crypto/tls"
 	"encoding/json"
 	"fmt"
-	"net"
-	"os"
-	"path/filepath"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -15,9 +13,28 @@ import (
 	"go.uber.org/zap"
 
 	"gfx.cafe/gfx/pggat/lib/fed"
-	"gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener"
 )
 
+var networkTypes = map[string]ListenerFunc{}
+
+type ListenerFunc func(ctx context.Context, addr caddy.NetworkAddress, config *tls.Config) (fed.Listener, error)
+
+func RegisterNetwork(network string, getListener ListenerFunc) {
+	network = strings.TrimSpace(strings.ToLower(network))
+
+	if network == "tcp" || network == "tcp4" || network == "tcp6" ||
+		network == "udp" || network == "udp4" || network == "udp6" ||
+		network == "unix" || network == "unixpacket" || network == "unixgram" ||
+		strings.HasPrefix("ip:", network) || strings.HasPrefix("ip4:", network) || strings.HasPrefix("ip6:", network) {
+		panic("network type " + network + " is reserved")
+	}
+
+	if _, ok := networkTypes[strings.ToLower(network)]; ok {
+		panic("network type " + network + " is already registered")
+	}
+	networkTypes[network] = getListener
+}
+
 type ListenerConfig struct {
 	Address        string          `json:"address"`
 	SSL            json.RawMessage `json:"ssl,omitempty" caddy:"namespace=pggat.ssl.servers inline_key=provider"`
@@ -78,19 +95,24 @@ func (T *Listener) Provision(ctx caddy.Context) error {
 }
 
 func (T *Listener) Start() error {
-	if T.networkAddress.Network == "unix" {
-		if err := os.MkdirAll(filepath.Dir(T.networkAddress.Host), 0o660); err != nil {
-			return err
+	listenerFunc, ok := networkTypes[T.networkAddress.Network]
+	if !ok {
+		listenerFunc, ok = networkTypes["default"]
+		if !ok {
+			return fmt.Errorf("no default listenerFunc registered. forgot to import gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener ?")
 		}
 	}
-	listener, err := T.networkAddress.Listen(context.Background(), 0, net.ListenConfig{})
+	var tlsConfig *tls.Config
+	if T.ssl != nil {
+		tlsConfig = T.ssl.ServerTLSConfig()
+	}
+	listener, err := listenerFunc(context.Background(), T.networkAddress, tlsConfig)
 	if err != nil {
 		return err
 	}
-	ncn := &netconnlistener.Listener{Listener: listener.(net.Listener)}
-	T.listener = ncn
+	T.listener = listener
 
-	T.log.Info("listening", zap.String("address", ncn.Listener.Addr().String()))
+	T.log.Info("listening", zap.String("address", T.networkAddress.String()))
 
 	return nil
 }
diff --git a/lib/gat/standard/standard.go b/lib/gat/standard/standard.go
index be606cfc..f1c51ab5 100644
--- a/lib/gat/standard/standard.go
+++ b/lib/gat/standard/standard.go
@@ -2,6 +2,7 @@ package standard
 
 import (
 	// base server
+
 	_ "gfx.cafe/gfx/pggat/lib/gat"
 
 	// matchers
@@ -47,4 +48,7 @@ import (
 	// pools
 	_ "gfx.cafe/gfx/pggat/lib/gat/handlers/pool/pools/basic"
 	_ "gfx.cafe/gfx/pggat/lib/gat/handlers/pool/pools/hybrid"
+
+	// listeners
+	_ "gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener"
 )
diff --git a/test/tester_test.go b/test/tester_test.go
index 9ffc5bd8..06393771 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -13,6 +13,7 @@ import (
 	"github.com/caddyserver/caddy/v2/caddyconfig"
 
 	"gfx.cafe/gfx/pggat/lib/auth/credentials"
+	"gfx.cafe/gfx/pggat/lib/bouncer"
 	"gfx.cafe/gfx/pggat/lib/gat"
 	"gfx.cafe/gfx/pggat/lib/gat/gatcaddyfile"
 	"gfx.cafe/gfx/pggat/lib/gat/handlers/pool"
@@ -22,6 +23,8 @@ import (
 	"gfx.cafe/gfx/pggat/lib/util/strutil"
 	"gfx.cafe/gfx/pggat/test"
 	"gfx.cafe/gfx/pggat/test/tests"
+
+	_ "gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener"
 )
 
 func wrapConfig(conf basic.Config) basic.Config {
@@ -97,6 +100,7 @@ func createServer(parent dialer, pools map[string]caddy.Module) (server gat.Serv
 				Dialer: pool.Dialer{
 					Address:     parent.Address,
 					Username:    parent.Username,
+					SSLMode:     bouncer.SSLModeDisable,
 					RawPassword: parent.Password,
 					Database:    parent.Database,
 				},
@@ -164,9 +168,10 @@ func TestTester(t *testing.T) {
 	control := pool.Dialer{
 		Address:  "localhost:5432",
 		Username: "postgres",
+		SSLMode:  bouncer.SSLModeDisable,
 		Credentials: credentials.Cleartext{
 			Username: "postgres",
-			Password: "password",
+			Password: "postgres",
 		},
 		Database: "postgres",
 	}
@@ -176,7 +181,7 @@ func TestTester(t *testing.T) {
 	parent, err := daisyChain(&config, dialer{
 		Address:  "localhost:5432",
 		Username: "postgres",
-		Password: "password",
+		Password: "postgres",
 		Database: "postgres",
 	}, 16)
 	if err != nil {
-- 
GitLab