From 85cea0ef8f5c53a802a0cb19810bea53f49252e4 Mon Sep 17 00:00:00 2001 From: a <a@tuxpa.in> Date: Mon, 17 Jun 2024 00:05:57 -0500 Subject: [PATCH] noot --- lib/fed/listeners/netconnlistener/listener.go | 28 +++++++++++- lib/gat/listen.go | 44 ++++++++++++++----- lib/gat/standard/standard.go | 4 ++ test/tester_test.go | 9 +++- 4 files changed, 71 insertions(+), 14 deletions(-) diff --git a/lib/fed/listeners/netconnlistener/listener.go b/lib/fed/listeners/netconnlistener/listener.go index e2156e40..00a74e5f 100644 --- a/lib/fed/listeners/netconnlistener/listener.go +++ b/lib/fed/listeners/netconnlistener/listener.go @@ -1,16 +1,42 @@ package netconnlistener import ( + "context" + "crypto/tls" + "log" "net" + "os" + "path/filepath" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/fed/codecs/netconncodec" + "gfx.cafe/gfx/pggat/lib/gat" + "github.com/caddyserver/caddy/v2" ) type Listener struct { Listener net.Listener } +func init() { + gat.RegisterNetwork("default", ListenerFunc) +} + +func ListenerFunc(ctx context.Context, addr caddy.NetworkAddress, config *tls.Config) (fed.Listener, error) { + if addr.Network == "unix" { + if err := os.MkdirAll(filepath.Dir(addr.Host), 0o660); err != nil { + return nil, err + } + } + listener, err := addr.Listen(context.Background(), 0, net.ListenConfig{}) + if err != nil { + return nil, err + } + log.Println("got fed conn") + ncn := &Listener{Listener: listener.(net.Listener)} + return ncn, nil +} + func (listener *Listener) Accept(fn func(*fed.Conn)) error { raw, err := listener.Listener.Accept() if err != nil { @@ -23,5 +49,5 @@ func (listener *Listener) Accept(fn func(*fed.Conn)) error { return nil } func (l *Listener) Close() error { - return l.Close() + return l.Listener.Close() } diff --git a/lib/gat/listen.go b/lib/gat/listen.go index dd9dbdbc..134f0617 100644 --- a/lib/gat/listen.go +++ b/lib/gat/listen.go @@ -2,11 +2,9 @@ package gat import ( "context" + "crypto/tls" "encoding/json" "fmt" - "net" - "os" - "path/filepath" "strconv" "strings" "sync/atomic" @@ -15,9 +13,28 @@ import ( "go.uber.org/zap" "gfx.cafe/gfx/pggat/lib/fed" - "gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener" ) +var networkTypes = map[string]ListenerFunc{} + +type ListenerFunc func(ctx context.Context, addr caddy.NetworkAddress, config *tls.Config) (fed.Listener, error) + +func RegisterNetwork(network string, getListener ListenerFunc) { + network = strings.TrimSpace(strings.ToLower(network)) + + if network == "tcp" || network == "tcp4" || network == "tcp6" || + network == "udp" || network == "udp4" || network == "udp6" || + network == "unix" || network == "unixpacket" || network == "unixgram" || + strings.HasPrefix("ip:", network) || strings.HasPrefix("ip4:", network) || strings.HasPrefix("ip6:", network) { + panic("network type " + network + " is reserved") + } + + if _, ok := networkTypes[strings.ToLower(network)]; ok { + panic("network type " + network + " is already registered") + } + networkTypes[network] = getListener +} + type ListenerConfig struct { Address string `json:"address"` SSL json.RawMessage `json:"ssl,omitempty" caddy:"namespace=pggat.ssl.servers inline_key=provider"` @@ -78,19 +95,24 @@ func (T *Listener) Provision(ctx caddy.Context) error { } func (T *Listener) Start() error { - if T.networkAddress.Network == "unix" { - if err := os.MkdirAll(filepath.Dir(T.networkAddress.Host), 0o660); err != nil { - return err + listenerFunc, ok := networkTypes[T.networkAddress.Network] + if !ok { + listenerFunc, ok = networkTypes["default"] + if !ok { + return fmt.Errorf("no default listenerFunc registered. forgot to import gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener ?") } } - listener, err := T.networkAddress.Listen(context.Background(), 0, net.ListenConfig{}) + var tlsConfig *tls.Config + if T.ssl != nil { + tlsConfig = T.ssl.ServerTLSConfig() + } + listener, err := listenerFunc(context.Background(), T.networkAddress, tlsConfig) if err != nil { return err } - ncn := &netconnlistener.Listener{Listener: listener.(net.Listener)} - T.listener = ncn + T.listener = listener - T.log.Info("listening", zap.String("address", ncn.Listener.Addr().String())) + T.log.Info("listening", zap.String("address", T.networkAddress.String())) return nil } diff --git a/lib/gat/standard/standard.go b/lib/gat/standard/standard.go index be606cfc..f1c51ab5 100644 --- a/lib/gat/standard/standard.go +++ b/lib/gat/standard/standard.go @@ -2,6 +2,7 @@ package standard import ( // base server + _ "gfx.cafe/gfx/pggat/lib/gat" // matchers @@ -47,4 +48,7 @@ import ( // pools _ "gfx.cafe/gfx/pggat/lib/gat/handlers/pool/pools/basic" _ "gfx.cafe/gfx/pggat/lib/gat/handlers/pool/pools/hybrid" + + // listeners + _ "gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener" ) diff --git a/test/tester_test.go b/test/tester_test.go index 9ffc5bd8..06393771 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -13,6 +13,7 @@ import ( "github.com/caddyserver/caddy/v2/caddyconfig" "gfx.cafe/gfx/pggat/lib/auth/credentials" + "gfx.cafe/gfx/pggat/lib/bouncer" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/gatcaddyfile" "gfx.cafe/gfx/pggat/lib/gat/handlers/pool" @@ -22,6 +23,8 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" "gfx.cafe/gfx/pggat/test" "gfx.cafe/gfx/pggat/test/tests" + + _ "gfx.cafe/gfx/pggat/lib/fed/listeners/netconnlistener" ) func wrapConfig(conf basic.Config) basic.Config { @@ -97,6 +100,7 @@ func createServer(parent dialer, pools map[string]caddy.Module) (server gat.Serv Dialer: pool.Dialer{ Address: parent.Address, Username: parent.Username, + SSLMode: bouncer.SSLModeDisable, RawPassword: parent.Password, Database: parent.Database, }, @@ -164,9 +168,10 @@ func TestTester(t *testing.T) { control := pool.Dialer{ Address: "localhost:5432", Username: "postgres", + SSLMode: bouncer.SSLModeDisable, Credentials: credentials.Cleartext{ Username: "postgres", - Password: "password", + Password: "postgres", }, Database: "postgres", } @@ -176,7 +181,7 @@ func TestTester(t *testing.T) { parent, err := daisyChain(&config, dialer{ Address: "localhost:5432", Username: "postgres", - Password: "password", + Password: "postgres", Database: "postgres", }, 16) if err != nil { -- GitLab