diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 4bc38292811216020c39cb94f5985ecc34b72cbf..875640e5e15276b5ffa19dd4ceae6c45d3046ca5 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "errors" "net/http" _ "net/http/pprof" @@ -11,69 +12,114 @@ import ( "tuxpa.in/a/zlog/log" + "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/metrics" "gfx.cafe/gfx/pggat/lib/gat/modules/cloud_sql_discovery" "gfx.cafe/gfx/pggat/lib/gat/modules/digitalocean_discovery" + "gfx.cafe/gfx/pggat/lib/gat/modules/net_listener" "gfx.cafe/gfx/pggat/lib/gat/modules/pgbouncer" - "gfx.cafe/gfx/pggat/lib/gat/modules/ssl_endpoint" "gfx.cafe/gfx/pggat/lib/gat/modules/zalando" "gfx.cafe/gfx/pggat/lib/gat/modules/zalando_operator_discovery" + "gfx.cafe/gfx/pggat/lib/util/certs" + "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func loadModule(mode string) (gat.Module, error) { +func addSSLEndpoint(server *gat.Server) error { + // back up ssl endpoint (for modules that don't have endpoints by default such as discovery) + cert, err := certs.SelfSign() + if err != nil { + return err + } + server.AddModule(&net_listener.Module{ + Config: net_listener.Config{ + Network: "tcp", + Address: ":5432", + AcceptOptions: frontends.AcceptOptions{ + SSLRequired: false, + SSLConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + AllowedStartupOptions: []strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + strutil.MakeCIString("extra_float_digits"), + strutil.MakeCIString("options"), + }, + }, + }, + }) + + return nil +} + +func addEnvModule(server *gat.Server, mode string) error { switch mode { case "pggat": conf, err := pgbouncer.Load(os.Args[1]) if err != nil { - return nil, err + return err } - return &pgbouncer.Module{ + + server.AddModule(&pgbouncer.Module{ Config: conf, - }, nil + }) case "pgbouncer": conf, err := pgbouncer.Load(os.Args[1]) if err != nil { - return nil, err + return err } - return &pgbouncer.Module{ + + server.AddModule(&pgbouncer.Module{ Config: conf, - }, nil + }) case "pgbouncer_spilo": conf, err := zalando.Load() if err != nil { - return nil, err + return err } - return &zalando.Module{ + + server.AddModule(&zalando.Module{ Config: conf, - }, nil + }) case "zalando_kubernetes_operator": conf, err := zalando_operator_discovery.Load() if err != nil { - return nil, err + return err } - return &zalando_operator_discovery.Module{ + + server.AddModule(&zalando_operator_discovery.Module{ Config: conf, - }, nil + }) + return addSSLEndpoint(server) case "google_cloud_sql": conf, err := cloud_sql_discovery.Load() if err != nil { - return nil, err + return err } - return &cloud_sql_discovery.Module{ + + server.AddModule(&cloud_sql_discovery.Module{ Config: conf, - }, nil + }) + return addSSLEndpoint(server) case "digitalocean_databases": conf, err := digitalocean_discovery.Load() if err != nil { - return nil, err + return err } - return &digitalocean_discovery.Module{ + + server.AddModule(&digitalocean_discovery.Module{ Config: conf, - }, nil + }) + return addSSLEndpoint(server) default: - return nil, errors.New("Unknown PGGAT_RUN_MODE: " + mode) + return errors.New("Unknown PGGAT_RUN_MODE: " + mode) } + + return nil } func main() { @@ -89,12 +135,8 @@ func main() { log.Printf("Starting pggat (%s)...", runMode) var server gat.Server - defer func() { - if err := server.Stop(); err != nil { - log.Printf("error stopping: %v", err) - } - }() + // handle interrupts c := make(chan os.Signal, 2) signal.Notify(c, os.Interrupt, syscall.SIGTERM) @@ -104,17 +146,14 @@ func main() { if err := server.Stop(); err != nil { log.Printf("error stopping: %v", err) } + + os.Exit(0) }() // load and add main module - module, err := loadModule(runMode) - if err != nil { + if err := addEnvModule(&server, runMode); err != nil { panic(err) } - server.AddModule(module) - - // back up ssl endpoint (for modules that don't have endpoints by default such as discovery) - server.AddModule(&ssl_endpoint.Module{}) go func() { var m metrics.Server @@ -126,9 +165,9 @@ func main() { } }() - err = server.Start() - if err != nil { + if err := server.Start(); err != nil { panic(err) } - return + + select {} } diff --git a/lib/gat/endpoint.go b/lib/gat/endpoint.go deleted file mode 100644 index 269f4a8af118222c03efbde114bedf4aaaecae5e..0000000000000000000000000000000000000000 --- a/lib/gat/endpoint.go +++ /dev/null @@ -1,12 +0,0 @@ -package gat - -import "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" - -type FrontendAcceptOptions = frontends.AcceptOptions - -type Endpoint struct { - Network string - Address string - - AcceptOptions FrontendAcceptOptions -} diff --git a/lib/gat/exposed.go b/lib/gat/exposed.go deleted file mode 100644 index fd1556207a408b1cf933602253e450d62b8f46b4..0000000000000000000000000000000000000000 --- a/lib/gat/exposed.go +++ /dev/null @@ -1,7 +0,0 @@ -package gat - -type Exposed interface { - Module - - Endpoints() []Endpoint -} diff --git a/lib/gat/listener.go b/lib/gat/listener.go new file mode 100644 index 0000000000000000000000000000000000000000..071508852ce4bbf82de2a6ba10493dab6994b673 --- /dev/null +++ b/lib/gat/listener.go @@ -0,0 +1,17 @@ +package gat + +import ( + "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" + "gfx.cafe/gfx/pggat/lib/fed" +) + +type AcceptedConn struct { + Conn fed.Conn + Params frontends.AcceptParams +} + +type Listener interface { + Module + + Accept() []<-chan AcceptedConn +} diff --git a/lib/gat/modules/cloud_sql_discovery/module.go b/lib/gat/modules/cloud_sql_discovery/module.go index d904096f2c59cdeb69996bd1090f1814c4b94443..0e595be5fa59841dd1ae9a2750bd1d95f505082b 100644 --- a/lib/gat/modules/cloud_sql_discovery/module.go +++ b/lib/gat/modules/cloud_sql_discovery/module.go @@ -5,6 +5,7 @@ import ( "time" "gfx.cafe/gfx/pggat/lib/bouncer" + "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/modules/discovery" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -44,3 +45,5 @@ func (T *Module) Start() error { } return T.Module.Start() } + +var _ gat.Starter = (*Module)(nil) diff --git a/lib/gat/modules/digitalocean_discovery/module.go b/lib/gat/modules/digitalocean_discovery/module.go index d004dd40a4e28735f56e423cc39d9062a9428eec..2e1cb43e628fd9fde15da5891fed1d7bf3884b7a 100644 --- a/lib/gat/modules/digitalocean_discovery/module.go +++ b/lib/gat/modules/digitalocean_discovery/module.go @@ -5,6 +5,7 @@ import ( "time" "gfx.cafe/gfx/pggat/lib/bouncer" + "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/modules/discovery" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -44,3 +45,5 @@ func (T *Module) Start() error { } return T.Module.Start() } + +var _ gat.Starter = (*Module)(nil) diff --git a/lib/gat/modules/discovery/module.go b/lib/gat/modules/discovery/module.go index 5a9068752ef3279d1e92188d4cd49300a80e60a2..dc7d799bb22eba35be694f54e2640a19fc5e4251 100644 --- a/lib/gat/modules/discovery/module.go +++ b/lib/gat/modules/discovery/module.go @@ -49,6 +49,14 @@ func (T *Module) Stop() error { return errors.New("discoverer not running") } close(T.closed) + + T.mu.Lock() + defer T.mu.Unlock() + T.pools.Range(func(user string, database string, p *pool.Pool) bool { + p.Close() + T.pools.Delete(user, database) + return true + }) return nil } @@ -494,3 +502,5 @@ func (T *Module) Lookup(user, database string) *gat.Pool { var _ gat.Module = (*Module)(nil) var _ gat.Provider = (*Module)(nil) +var _ gat.Starter = (*Module)(nil) +var _ gat.Stopper = (*Module)(nil) diff --git a/lib/gat/modules/net_listener/config.go b/lib/gat/modules/net_listener/config.go new file mode 100644 index 0000000000000000000000000000000000000000..6484dc15d4a046b883f798a1926fb497324fec3b --- /dev/null +++ b/lib/gat/modules/net_listener/config.go @@ -0,0 +1,9 @@ +package net_listener + +import "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" + +type Config struct { + Network string + Address string + AcceptOptions frontends.AcceptOptions +} diff --git a/lib/gat/modules/net_listener/module.go b/lib/gat/modules/net_listener/module.go new file mode 100644 index 0000000000000000000000000000000000000000..17dd194ceb71e4d679019f75055a840e5c5a97f9 --- /dev/null +++ b/lib/gat/modules/net_listener/module.go @@ -0,0 +1,95 @@ +package net_listener + +import ( + "errors" + "net" + + "tuxpa.in/a/zlog/log" + + "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" + "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/gat" +) + +type Module struct { + Config + + listener net.Listener + accepted chan gat.AcceptedConn +} + +func (*Module) GatModule() {} + +func (T *Module) Start() error { + if T.listener != nil { + // in case this listener was started early + return nil + } + + var err error + T.listener, err = net.Listen(T.Network, T.Address) + if err != nil { + return err + } + log.Printf("listening on %v", T.listener.Addr()) + + T.accepted = make(chan gat.AcceptedConn) + go T.acceptLoop() + + return nil +} + +func (T *Module) Stop() error { + return T.listener.Close() +} + +func (T *Module) Addr() net.Addr { + if T.listener == nil { + return nil + } + return T.listener.Addr() +} + +func (T *Module) accept(raw net.Conn) { + conn := fed.WrapNetConn(raw) + ctx := frontends.AcceptContext{ + Conn: conn, + Options: T.AcceptOptions, + } + params, err := frontends.Accept(&ctx) + if err != nil { + log.Printf("failed to accept conn: %v", err) + return + } + _ = params // TODO(garet) + T.accepted <- gat.AcceptedConn{ + Conn: conn, + Params: params, + } +} + +func (T *Module) acceptLoop() { + for { + conn, err := T.listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("failed to accept conn: %v", err) + continue + } + + T.accept(conn) + } +} + +func (T *Module) Accept() []<-chan gat.AcceptedConn { + return []<-chan gat.AcceptedConn{ + T.accepted, + } +} + +var _ gat.Module = (*Module)(nil) +var _ gat.Listener = (*Module)(nil) +var _ gat.Starter = (*Module)(nil) +var _ gat.Stopper = (*Module)(nil) diff --git a/lib/gat/modules/pgbouncer/module.go b/lib/gat/modules/pgbouncer/module.go index 53ab543ed2ac364a41888e92d0e6153baeed8885..d3330fff1dfc741b09516f961a61ca25ea63e341 100644 --- a/lib/gat/modules/pgbouncer/module.go +++ b/lib/gat/modules/pgbouncer/module.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "strings" + "sync" "time" "tuxpa.in/a/zlog/log" @@ -16,6 +17,7 @@ import ( "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/metrics" + "gfx.cafe/gfx/pggat/lib/gat/modules/net_listener" "gfx.cafe/gfx/pggat/lib/gat/pool" "gfx.cafe/gfx/pggat/lib/gat/pool/pools/session" "gfx.cafe/gfx/pggat/lib/gat/pool/pools/transaction" @@ -34,6 +36,106 @@ type Module struct { Config pools maps.TwoKey[string, string, *gat.Pool] + mu sync.RWMutex + + tcpListener net_listener.Module + unixListener net_listener.Module +} + +func (T *Module) Start() error { + trackedParameters := append([]strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, T.PgBouncer.TrackExtraParameters...) + + allowedStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...) + var sslConfig *tls.Config + if T.PgBouncer.ClientTLSCertFile != "" && T.PgBouncer.ClientTLSKeyFile != "" { + certificate, err := tls.LoadX509KeyPair(T.PgBouncer.ClientTLSCertFile, T.PgBouncer.ClientTLSKeyFile) + if err != nil { + log.Printf("error loading X509 keypair: %v", err) + } else { + sslConfig = &tls.Config{ + Certificates: []tls.Certificate{ + certificate, + }, + } + } + } + + acceptOptions := frontends.AcceptOptions{ + SSLRequired: T.PgBouncer.ClientTLSSSLMode.IsRequired(), + SSLConfig: sslConfig, + AllowedStartupOptions: allowedStartupParameters, + } + + if T.PgBouncer.ListenAddr != "" { + listenAddr := T.PgBouncer.ListenAddr + if listenAddr == "*" { + listenAddr = "" + } + + listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.PgBouncer.ListenPort)) + + T.tcpListener = net_listener.Module{ + Config: net_listener.Config{ + Network: "tcp", + Address: listen, + AcceptOptions: acceptOptions, + }, + } + if err := T.tcpListener.Start(); err != nil { + return err + } + } + + // listen on unix socket + dir := T.PgBouncer.UnixSocketDir + port := T.PgBouncer.ListenPort + + if !strings.HasSuffix(dir, "/") { + dir = dir + "/" + } + dir = dir + ".s.PGSQL." + strconv.Itoa(port) + + T.unixListener = net_listener.Module{ + Config: net_listener.Config{ + Network: "unix", + Address: dir, + AcceptOptions: acceptOptions, + }, + } + if err := T.unixListener.Start(); err != nil { + return err + } + + return nil +} + +func (T *Module) Stop() error { + var err error + if T.PgBouncer.ListenAddr != "" { + if err2 := T.tcpListener.Stop(); err2 != nil { + err = err2 + } + } + + if err2 := T.unixListener.Stop(); err2 != nil { + err = err2 + } + + T.mu.Lock() + defer T.mu.Unlock() + T.pools.Range(func(user string, database string, p *gat.Pool) bool { + p.Close() + T.pools.Delete(user, database) + return true + }) + + return err } func (T *Module) getPassword(user, database string) (string, bool) { @@ -151,6 +253,8 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { } p := pool.NewPool(poolOptions) + T.mu.Lock() + defer T.mu.Unlock() T.pools.Store(user, database, p) var d recipe.Dialer @@ -236,79 +340,27 @@ func (T *Module) Lookup(user, database string) *gat.Pool { } func (T *Module) ReadMetrics(metrics *metrics.Pools) { + T.mu.RLock() + defer T.mu.RUnlock() T.pools.Range(func(_ string, _ string, p *gat.Pool) bool { p.ReadMetrics(&metrics.Pool) return true }) } -func (T *Module) Endpoints() []gat.Endpoint { - trackedParameters := append([]strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), - }, T.PgBouncer.TrackExtraParameters...) - - allowedStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...) - var sslConfig *tls.Config - if T.PgBouncer.ClientTLSCertFile != "" && T.PgBouncer.ClientTLSKeyFile != "" { - certificate, err := tls.LoadX509KeyPair(T.PgBouncer.ClientTLSCertFile, T.PgBouncer.ClientTLSKeyFile) - if err != nil { - log.Printf("error loading X509 keypair: %v", err) - } else { - sslConfig = &tls.Config{ - Certificates: []tls.Certificate{ - certificate, - }, - } - } - } - - acceptOptions := frontends.AcceptOptions{ - SSLRequired: T.PgBouncer.ClientTLSSSLMode.IsRequired(), - SSLConfig: sslConfig, - AllowedStartupOptions: allowedStartupParameters, - } - - var endpoints []gat.Endpoint - +func (T *Module) Accept() []<-chan gat.AcceptedConn { + var accept []<-chan gat.AcceptedConn if T.PgBouncer.ListenAddr != "" { - listenAddr := T.PgBouncer.ListenAddr - if listenAddr == "*" { - listenAddr = "" - } - - listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.PgBouncer.ListenPort)) - - endpoints = append(endpoints, gat.Endpoint{ - Network: "tcp", - Address: listen, - AcceptOptions: acceptOptions, - }) + accept = append(accept, T.tcpListener.Accept()...) } - - // listen on unix socket - dir := T.PgBouncer.UnixSocketDir - port := T.PgBouncer.ListenPort - - if !strings.HasSuffix(dir, "/") { - dir = dir + "/" - } - dir = dir + ".s.PGSQL." + strconv.Itoa(port) - - endpoints = append(endpoints, gat.Endpoint{ - Network: "unix", - Address: dir, - AcceptOptions: acceptOptions, - }) - - return endpoints + accept = append(accept, T.unixListener.Accept()...) + return accept } func (T *Module) GatModule() {} var _ gat.Module = (*Module)(nil) var _ gat.Provider = (*Module)(nil) -var _ gat.Exposed = (*Module)(nil) +var _ gat.Listener = (*Module)(nil) +var _ gat.Starter = (*Module)(nil) +var _ gat.Stopper = (*Module)(nil) diff --git a/lib/gat/modules/ssl_endpoint/module.go b/lib/gat/modules/ssl_endpoint/module.go deleted file mode 100644 index 5880e684bcb1320ae249c58fd2aab993d17cf90c..0000000000000000000000000000000000000000 --- a/lib/gat/modules/ssl_endpoint/module.go +++ /dev/null @@ -1,105 +0,0 @@ -package ssl_endpoint - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "math/big" - "net" - "time" - - "tuxpa.in/a/zlog/log" - - "gfx.cafe/gfx/pggat/lib/gat" - "gfx.cafe/gfx/pggat/lib/util/strutil" -) - -type Module struct { - config *tls.Config -} - -func (T *Module) generateKeys() error { - // generate private key - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return err - } - - keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment - - notBefore := time.Now() - notAfter := notBefore.Add(3 * 30 * 24 * time.Hour) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - if err != nil { - return err - } - - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"GFX Labs"}, - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: keyUsage, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - // TODO(garet) - template.IPAddresses = append(template.IPAddresses, net.ParseIP("192.168.1.1")) - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return err - } - - var cert tls.Certificate - cert.PrivateKey = priv - cert.Certificate = append(cert.Certificate, derBytes) - - T.config = &tls.Config{ - Certificates: []tls.Certificate{ - cert, - }, - } - return nil -} - -func (T *Module) GatModule() {} - -func (T *Module) Endpoints() []gat.Endpoint { - if T.config == nil { - if err := T.generateKeys(); err != nil { - log.Printf("failed to generate ssl certificate: %v", err) - } - } - - return []gat.Endpoint{ - { - Network: "tcp", - Address: ":5432", - AcceptOptions: gat.FrontendAcceptOptions{ - SSLRequired: false, - SSLConfig: T.config, - AllowedStartupOptions: []strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), - strutil.MakeCIString("extra_float_digits"), - strutil.MakeCIString("options"), - }, - }, - }, - } -} - -var _ gat.Module = (*Module)(nil) -var _ gat.Exposed = (*Module)(nil) diff --git a/lib/gat/modules/zalando/module.go b/lib/gat/modules/zalando/module.go index 0115e362973f42b0097d52c6f9585770c6be0d5a..b9e79ea0e87c2871b9b5ee40fb33132bcc930f67 100644 --- a/lib/gat/modules/zalando/module.go +++ b/lib/gat/modules/zalando/module.go @@ -4,6 +4,7 @@ import ( "fmt" "gfx.cafe/gfx/pggat/lib/bouncer" + "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/modules/pgbouncer" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -64,5 +65,7 @@ func (T *Module) Start() error { Config: pgb, } - return nil + return T.Module.Start() } + +var _ gat.Starter = (*Module)(nil) diff --git a/lib/gat/modules/zalando_operator_discovery/module.go b/lib/gat/modules/zalando_operator_discovery/module.go index 68d2cef6a406dcbc7da91b6aa05b693c54ca4aff..246670062757c90fbc67e7d4c3837c131005e55c 100644 --- a/lib/gat/modules/zalando_operator_discovery/module.go +++ b/lib/gat/modules/zalando_operator_discovery/module.go @@ -5,6 +5,7 @@ import ( "time" "gfx.cafe/gfx/pggat/lib/bouncer" + "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/modules/discovery" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -45,3 +46,5 @@ func (T *Module) Start() error { return T.Module.Start() } + +var _ gat.Starter = (*Module)(nil) diff --git a/lib/gat/server.go b/lib/gat/server.go index aebac874e542c491ff06c48b2c87eb5ea8569cc1..bfd9536226b8924deba44205c62db5b7aee383ef 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -3,26 +3,23 @@ package gat import ( "errors" "io" - "net" "tuxpa.in/a/zlog/log" "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/metrics" - "gfx.cafe/gfx/pggat/lib/util/flip" + "gfx.cafe/gfx/pggat/lib/util/chans" "gfx.cafe/gfx/pggat/lib/util/maps" ) type Server struct { modules []Module providers []Provider - exposed []Exposed + listeners []Listener starters []Starter stoppers []Stopper - listeners []net.Listener - keys maps.RWLocked[[8]byte, *Pool] } @@ -31,8 +28,8 @@ func (T *Server) AddModule(module Module) { if provider, ok := module.(Provider); ok { T.providers = append(T.providers, provider) } - if listener, ok := module.(Exposed); ok { - T.exposed = append(T.exposed, listener) + if listener, ok := module.(Listener); ok { + T.listeners = append(T.listeners, listener) } if starter, ok := module.(Starter); ok { T.starters = append(T.starters, starter) @@ -100,78 +97,6 @@ func (T *Server) serve(conn fed.Conn, params frontends.AcceptParams) error { return p.Serve(conn, params.InitialParameters, auth.BackendKey) } -func (T *Server) accept(raw net.Conn, acceptOptions FrontendAcceptOptions) { - conn := fed.WrapNetConn(raw) - - defer func() { - _ = conn.Close() - }() - - ctx := frontends.AcceptContext{ - Conn: conn, - Options: acceptOptions, - } - params, err2 := frontends.Accept(&ctx) - if err2 != nil { - log.Print("error accepting client: ", err2) - return - } - - err := T.serve(conn, params) - if err != nil && !errors.Is(err, io.EOF) { - log.Print("error serving client: ", err) - return - } -} - -func (T *Server) startListening(network, address string) (net.Listener, error) { - listener, err := net.Listen(network, address) - if err != nil { - return nil, err - } - T.listeners = append(T.listeners, listener) - - log.Printf("listening on %s(%s)", network, address) - - return listener, nil -} - -func (T *Server) listen(listener net.Listener, acceptOptions FrontendAcceptOptions) error { - for { - raw, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - } - - go T.accept(raw, acceptOptions) - } - - return nil -} - -func (T *Server) listenAndServe() error { - var b flip.Bank - - if len(T.exposed) > 0 { - l := T.exposed[0] - endpoints := l.Endpoints() - for _, endpoint := range endpoints { - e := endpoint - b.Queue(func() error { - listener, err := T.startListening(e.Network, e.Address) - if err != nil { - return err - } - return T.listen(listener, e.AcceptOptions) - }) - } - } - - return b.Wait() -} - func (T *Server) ReadMetrics(m *metrics.Server) { for _, provider := range T.providers { provider.ReadMetrics(&m.Pools) @@ -185,17 +110,32 @@ func (T *Server) Start() error { } } - return T.listenAndServe() -} + var accept []<-chan AcceptedConn -func (T *Server) Stop() error { - var err error for _, listener := range T.listeners { - if err2 := listener.Close(); err2 != nil { - err = err2 - } + accept = append(accept, listener.Accept()...) } + go func() { + acceptor := chans.NewMultiRecv(accept) + for { + accepted, ok := acceptor.Recv() + if !ok { + break + } + go func() { + if err := T.serve(accepted.Conn, accepted.Params); err != nil && !errors.Is(err, io.EOF) { + log.Printf("failed to serve client: %v", err) + } + }() + } + }() + + return nil +} + +func (T *Server) Stop() error { + var err error for _, stopper := range T.stoppers { if err2 := stopper.Stop(); err2 != nil { err = err2 diff --git a/lib/util/certs/self.go b/lib/util/certs/self.go new file mode 100644 index 0000000000000000000000000000000000000000..e4bd2b8ff47905083a5fb36b757960251f9250d9 --- /dev/null +++ b/lib/util/certs/self.go @@ -0,0 +1,57 @@ +package certs + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "time" +) + +func SelfSign() (tls.Certificate, error) { + // generate private key + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return tls.Certificate{}, err + } + + keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment + + notBefore := time.Now() + notAfter := notBefore.Add(3 * 30 * 24 * time.Hour) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"GFX Labs"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + // TODO(garet) + template.IPAddresses = append(template.IPAddresses, net.ParseIP("127.0.0.1")) + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + var cert tls.Certificate + cert.PrivateKey = priv + cert.Certificate = append(cert.Certificate, derBytes) + + return cert, nil +} diff --git a/lib/util/chans/multi.go b/lib/util/chans/multi.go new file mode 100644 index 0000000000000000000000000000000000000000..6182115d442aebd635f72c061a4bc8434f330753 --- /dev/null +++ b/lib/util/chans/multi.go @@ -0,0 +1,39 @@ +package chans + +import ( + "reflect" + + "gfx.cafe/gfx/pggat/lib/util/slices" +) + +type MultiRecv[T any] struct { + cases []reflect.SelectCase +} + +func NewMultiRecv[T any](cases []<-chan T) *MultiRecv[T] { + c := make([]reflect.SelectCase, 0, len(cases)) + for _, ch := range cases { + c = append(c, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ch), + }) + } + return &MultiRecv[T]{ + cases: c, + } +} + +func (c *MultiRecv[T]) Recv() (T, bool) { + for { + if len(c.cases) == 0 { + return *new(T), false + } + + idx, value, ok := reflect.Select(c.cases) + if !ok { + c.cases = slices.DeleteIndex(c.cases, idx) + continue + } + return value.Interface().(T), true + } +} diff --git a/lib/util/slices/remove.go b/lib/util/slices/remove.go index 1848711f5065da4f78b972da1e0675d17ecd24bc..04ca24c193ccdf31c21187f0b9452dc78f865346 100644 --- a/lib/util/slices/remove.go +++ b/lib/util/slices/remove.go @@ -8,7 +8,12 @@ func Remove[T comparable](slice []T, item T) []T { if i == -1 { return slice } - copy(slice[i:], slice[i+1:]) + return RemoveIndex(slice, i) +} + +func RemoveIndex[T any](slice []T, idx int) []T { + item := slice[idx] + copy(slice[idx:], slice[idx+1:]) slice[len(slice)-1] = item return slice[:len(slice)-1] } @@ -19,7 +24,11 @@ func Delete[T comparable](slice []T, item T) []T { if i == -1 { return slice } - copy(slice[i:], slice[i+1:]) + return DeleteIndex(slice, i) +} + +func DeleteIndex[T any](slice []T, idx int) []T { + copy(slice[idx:], slice[idx+1:]) slice[len(slice)-1] = *new(T) return slice[:len(slice)-1] } diff --git a/test/tester_test.go b/test/tester_test.go index dc53e46b6768b02c95e2c89f5a684525d16cfff3..0bf1ff4eeb61c254a40dc8a325f35aef106885cf 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -12,8 +12,8 @@ import ( "gfx.cafe/gfx/pggat/lib/auth" "gfx.cafe/gfx/pggat/lib/auth/credentials" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" - "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/modules/net_listener" "gfx.cafe/gfx/pggat/lib/gat/modules/raw_pools" "gfx.cafe/gfx/pggat/lib/gat/pool" "gfx.cafe/gfx/pggat/lib/gat/pool/pools/session" @@ -46,18 +46,21 @@ func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Di m.Add("runner", "pool", p) server.AddModule(m) - listener, err := server.listen("tcp", ":0") - if err != nil { + l := &net_listener.Module{ + Config: net_listener.Config{ + Network: "tcp", + Address: ":0", + }, + } + if err := l.Start(); err != nil { return recipe.Dialer{}, err } - port := listener.Addr().(*net.TCPAddr).Port + port := l.Addr().(*net.TCPAddr).Port + server.AddModule(l) - go func() { - err := server.serve(listener, frontends.AcceptOptions{}) - if err != nil { - panic(err) - } - }() + if err := server.Start(); err != nil { + panic(err) + } control = recipe.Dialer{ Network: "tcp", @@ -128,19 +131,24 @@ func TestTester(t *testing.T) { server.AddModule(m) - listener, err := server.listen("tcp", ":0") - if err != nil { + l := &net_listener.Module{ + Config: net_listener.Config{ + Network: "tcp", + Address: ":0", + }, + } + if err = l.Start(); err != nil { t.Error(err) return } - port := listener.Addr().(*net.TCPAddr).Port + port := l.Addr().(*net.TCPAddr).Port - go func() { - err := server.serve(listener, frontends.AcceptOptions{}) - if err != nil { - t.Error(err) - } - }() + server.AddModule(l) + + if err = server.Start(); err != nil { + t.Error(err) + return + } transactionDialer := recipe.Dialer{ Network: "tcp",