diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 5e7ea6fd91e1ae26bea0771730ff6b8db294f365..aed9645e151ee8ac2afc24de99d13d73f5c4bf7a 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -8,13 +8,12 @@ import ( "tuxpa.in/a/zlog/log" + "pggat/lib/gat" "pggat/lib/gat/modules/cloud_sql_discovery" "pggat/lib/gat/modules/digitalocean_discovery" "pggat/lib/gat/modules/pgbouncer" "pggat/lib/gat/modules/zalando" "pggat/lib/gat/modules/zalando_operator_discovery" - - "pggat/lib/gat" ) func loadModule(mode string) (gat.Module, error) { diff --git a/lib/gat/endpoint.go b/lib/gat/endpoint.go new file mode 100644 index 0000000000000000000000000000000000000000..a815745d1ce1d68b3f99b089b34bdee297b02d82 --- /dev/null +++ b/lib/gat/endpoint.go @@ -0,0 +1,12 @@ +package gat + +import "pggat/lib/bouncer/frontends/v0" + +type FrontendAcceptOptions = frontends.AcceptOptions + +type Endpoint struct { + Network string + Address string + + AcceptOptions FrontendAcceptOptions +} diff --git a/lib/gat/listener.go b/lib/gat/listener.go index 07aee0914253a44ecacc70933c61ad774c2fc3f9..71893dbab6362b65b1fce73af27ccd8384fd7370 100644 --- a/lib/gat/listener.go +++ b/lib/gat/listener.go @@ -1,35 +1,7 @@ package gat -import ( - "net" +type Listener interface { + Module - "pggat/lib/bouncer/frontends/v0" - "pggat/lib/fed" -) - -type FrontendAcceptOptions = frontends.AcceptOptions - -type Listener struct { - Listener net.Listener - Options FrontendAcceptOptions -} - -func (T Listener) Accept() (fed.Conn, error) { - raw, err := T.Listener.Accept() - if err != nil { - return nil, err - } - conn := fed.WrapNetConn(raw) - _, err = frontends.Accept(&frontends.AcceptContext{ - Conn: conn, - Options: T.Options, - }) - if err != nil { - return nil, err - } - return conn, nil -} - -func (T Listener) Close() error { - return T.Listener.Close() + Endpoints() []Endpoint } diff --git a/lib/gat/module.go b/lib/gat/module.go index aec1c3de98cf1c6a6cc78fedb44945b6ecee40fd..0ec77c1684af9d0f78fd46c73493fa1ebbf33306 100644 --- a/lib/gat/module.go +++ b/lib/gat/module.go @@ -1,8 +1,5 @@ package gat -type ModuleInfo struct { -} - type Module interface { - GatModule() ModuleInfo + GatModule() } diff --git a/lib/gat/modules/discovery/config.go b/lib/gat/modules/discovery/config.go index 45390982ce87b6cebae7f751e841042437a5f33b..91f497806d06a6ad8570715b42349f0996cb35b2 100644 --- a/lib/gat/modules/discovery/config.go +++ b/lib/gat/modules/discovery/config.go @@ -1,6 +1,8 @@ package discovery -import "time" +import ( + "time" +) type Config struct { // ReconcilePeriod is how often the module should check for changes. 0 = disable diff --git a/lib/gat/modules/discovery/module.go b/lib/gat/modules/discovery/module.go index 6907d10128e213a6ae0489f0f34d778c276bfb18..262c3ad2a3c3da4487ca8df4633e57cb5b342e46 100644 --- a/lib/gat/modules/discovery/module.go +++ b/lib/gat/modules/discovery/module.go @@ -6,15 +6,14 @@ import ( ) type Module struct { + config Config } func NewModule(config Config) (*Module, error) { } -func (T *Module) GatModule() gat.ModuleInfo { - // TODO(garet) -} +func (T *Module) GatModule() {} func (T *Module) ReadMetrics(metrics *metrics.Pools) { // TODO implement me diff --git a/lib/gat/modules/pgbouncer/module.go b/lib/gat/modules/pgbouncer/module.go index 7784633a2c115bd543198f4cc9b1b7d2b777f62d..4c4e92385fe9fd1b7f77f158b04eb14ddd0858f4 100644 --- a/lib/gat/modules/pgbouncer/module.go +++ b/lib/gat/modules/pgbouncer/module.go @@ -13,6 +13,7 @@ import ( "pggat/lib/auth/credentials" "pggat/lib/bouncer/backends/v0" + "pggat/lib/bouncer/frontends/v0" "pggat/lib/gat" "pggat/lib/gat/metrics" "pggat/lib/gat/pool" @@ -247,9 +248,73 @@ func (T *Module) ReadMetrics(metrics *metrics.Pools) { }) } -func (T *Module) GatModule() gat.ModuleInfo { - return gat.ModuleInfo{} // TODO(garet) +func (T *Module) Endpoints() []gat.Endpoint { + trackedParameters := append([]strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, T.config.PgBouncer.TrackExtraParameters...) + + allowedStartupParameters := append(trackedParameters, T.config.PgBouncer.IgnoreStartupParameters...) + var sslConfig *tls.Config + if T.config.PgBouncer.ClientTLSCertFile != "" && T.config.PgBouncer.ClientTLSKeyFile != "" { + certificate, err := tls.LoadX509KeyPair(T.config.PgBouncer.ClientTLSCertFile, T.config.PgBouncer.ClientTLSKeyFile) + if err != nil { + log.Printf("error loading X509 keypair: %v", err) + } else { + sslConfig = &tls.Config{ + Certificates: []tls.Certificate{ + certificate, + }, + } + } + } + + acceptOptions := frontends.AcceptOptions{ + SSLRequired: T.config.PgBouncer.ClientTLSSSLMode.IsRequired(), + SSLConfig: sslConfig, + AllowedStartupOptions: allowedStartupParameters, + } + + var endpoints []gat.Endpoint + + if T.config.PgBouncer.ListenAddr != "" { + listenAddr := T.config.PgBouncer.ListenAddr + if listenAddr == "*" { + listenAddr = "" + } + + listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.config.PgBouncer.ListenPort)) + + endpoints = append(endpoints, gat.Endpoint{ + Network: "tcp", + Address: listen, + AcceptOptions: acceptOptions, + }) + } + + // listen on unix socket + dir := T.config.PgBouncer.UnixSocketDir + port := T.config.PgBouncer.ListenPort + + if !strings.HasSuffix(dir, "/") { + dir = dir + "/" + } + dir = dir + ".s.PGSQL." + strconv.Itoa(port) + + endpoints = append(endpoints, gat.Endpoint{ + Network: "unix", + Address: dir, + AcceptOptions: acceptOptions, + }) + + return endpoints } +func (T *Module) GatModule() {} + var _ gat.Module = (*Module)(nil) var _ gat.Provider = (*Module)(nil) +var _ gat.Listener = (*Module)(nil) diff --git a/lib/gat/server.go b/lib/gat/server.go index 1f72ee2b1512a2e7080ab05be17a0a41bca336d5..9c18e69701316e0f8cfd0010a22dbd5d9025e8cf 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -3,12 +3,14 @@ package gat import ( "errors" "io" + "net" "tuxpa.in/a/zlog/log" "pggat/lib/bouncer/frontends/v0" "pggat/lib/fed" "pggat/lib/gat/metrics" + "pggat/lib/util/beforeexit" "pggat/lib/util/flip" "pggat/lib/util/maps" ) @@ -16,6 +18,7 @@ import ( type Server struct { modules []Module providers []Provider + listeners []Listener keys maps.RWLocked[[8]byte, *Pool] } @@ -25,6 +28,9 @@ func (T *Server) AddModule(module Module) { if provider, ok := module.(Provider); ok { T.providers = append(T.providers, provider) } + if listener, ok := module.(Listener); ok { + T.listeners = append(T.listeners, listener) + } } func (T *Server) cancel(key [8]byte) error { @@ -85,41 +91,70 @@ func (T *Server) serve(conn fed.Conn, params frontends.AcceptParams) error { return p.Serve(conn, params.InitialParameters, auth.BackendKey) } -func (T *Server) Serve(listener Listener) error { - raw, err := listener.Listener.Accept() +func (T *Server) accept(raw net.Conn, acceptOptions FrontendAcceptOptions) { + conn := fed.WrapNetConn(raw) + + defer func() { + _ = conn.Close() + }() + + ctx := frontends.AcceptContext{ + Conn: conn, + Options: acceptOptions, + } + params, err2 := frontends.Accept(&ctx) + if err2 != nil { + log.Print("error accepting client: ", err2) + return + } + + err := T.serve(conn, params) + if err != nil && !errors.Is(err, io.EOF) { + log.Print("error serving client: ", err) + return + } +} + +func (T *Server) listenAndServe(endpoint Endpoint) error { + listener, err := net.Listen(endpoint.Network, endpoint.Address) if err != nil { return err } - conn := fed.WrapNetConn(raw) + if endpoint.Network == "unix" { + beforeexit.Run(func() { + _ = listener.Close() + }) + } - go func() { - defer func() { - _ = conn.Close() - }() + log.Printf("listening on %s(%s)", endpoint.Network, endpoint.Address) - ctx := frontends.AcceptContext{ - Conn: conn, - Options: listener.Options, - } - params, err2 := frontends.Accept(&ctx) - if err2 != nil { - log.Print("error accepting client: ", err2) - return + for { + raw, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + break + } } - err := T.serve(conn, params) - if err != nil && !errors.Is(err, io.EOF) { - log.Print("error serving client: ", err) - return - } - }() + go T.accept(raw, endpoint.AcceptOptions) + } + return nil } func (T *Server) ListenAndServe() error { var b flip.Bank - // TODO(garet) add listeners to bank + if len(T.listeners) > 0 { + l := T.listeners[0] + endpoints := l.Endpoints() + for _, endpoint := range endpoints { + e := endpoint + b.Queue(func() error { + return T.listenAndServe(e) + }) + } + } return b.Wait() }