diff --git a/cmd/caddygat/main.go b/cmd/caddygat/main.go
index 775c9e1504af14dd82da6544fb16555581d3ef83..abe74c6c5e8475f7808bc43e1345111a585f57d4 100644
--- a/cmd/caddygat/main.go
+++ b/cmd/caddygat/main.go
@@ -7,7 +7,6 @@ import (
_ "gfx.cafe/gfx/pggat/lib/gat/modules/cloud_sql_discovery"
_ "gfx.cafe/gfx/pggat/lib/gat/modules/digitalocean_discovery"
_ "gfx.cafe/gfx/pggat/lib/gat/modules/pgbouncer"
- _ "gfx.cafe/gfx/pggat/lib/gat/modules/ssl_endpoint"
_ "gfx.cafe/gfx/pggat/lib/gat/modules/zalando"
_ "gfx.cafe/gfx/pggat/lib/gat/modules/zalando_operator_discovery"
)
diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go
index 875640e5e15276b5ffa19dd4ceae6c45d3046ca5..3d5f40436ea905e5d29b57bc350bd3ece578766f 100644
--- a/cmd/cgat/main.go
+++ b/cmd/cgat/main.go
@@ -25,13 +25,14 @@ import (
"gfx.cafe/gfx/pggat/lib/util/strutil"
)
-func addSSLEndpoint(server *gat.Server) error {
+func createSSLModule() (gat.Module, error) {
// back up ssl endpoint (for modules that don't have endpoints by default such as discovery)
cert, err := certs.SelfSign()
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&net_listener.Module{
+
+ return &net_listener.Module{
Config: net_listener.Config{
Network: "tcp",
Address: ":5432",
@@ -51,75 +52,68 @@ func addSSLEndpoint(server *gat.Server) error {
},
},
},
- })
-
- return nil
+ }, nil
}
-func addEnvModule(server *gat.Server, mode string) error {
+func loadModule(mode string) (gat.Module, error) {
switch mode {
case "pggat":
conf, err := pgbouncer.Load(os.Args[1])
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&pgbouncer.Module{
+ return &pgbouncer.Module{
Config: conf,
- })
+ }, nil
case "pgbouncer":
conf, err := pgbouncer.Load(os.Args[1])
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&pgbouncer.Module{
+ return &pgbouncer.Module{
Config: conf,
- })
+ }, nil
case "pgbouncer_spilo":
conf, err := zalando.Load()
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&zalando.Module{
+ return &zalando.Module{
Config: conf,
- })
+ }, nil
case "zalando_kubernetes_operator":
conf, err := zalando_operator_discovery.Load()
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&zalando_operator_discovery.Module{
+ return &zalando_operator_discovery.Module{
Config: conf,
- })
- return addSSLEndpoint(server)
+ }, nil
case "google_cloud_sql":
conf, err := cloud_sql_discovery.Load()
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&cloud_sql_discovery.Module{
+ return &cloud_sql_discovery.Module{
Config: conf,
- })
- return addSSLEndpoint(server)
+ }, nil
case "digitalocean_databases":
conf, err := digitalocean_discovery.Load()
if err != nil {
- return err
+ return nil, err
}
- server.AddModule(&digitalocean_discovery.Module{
+ return &digitalocean_discovery.Module{
Config: conf,
- })
- return addSSLEndpoint(server)
+ }, nil
default:
- return errors.New("Unknown PGGAT_RUN_MODE: " + mode)
+ return nil, errors.New("Unknown PGGAT_RUN_MODE: " + mode)
}
-
- return nil
}
func main() {
@@ -134,7 +128,24 @@ func main() {
log.Printf("Starting pggat (%s)...", runMode)
- var server gat.Server
+ // load modules
+ var modules []gat.Module
+
+ module, err := loadModule(runMode)
+ if err != nil {
+ panic(err)
+ }
+ modules = append(modules, module)
+
+ if _, ok := module.(gat.Listener); !ok {
+ endpoint, err := createSSLModule()
+ if err != nil {
+ panic(err)
+ }
+ modules = append(modules, endpoint)
+ }
+
+ server := gat.NewServer(modules...)
// handle interrupts
c := make(chan os.Signal, 2)
@@ -150,11 +161,6 @@ func main() {
os.Exit(0)
}()
- // load and add main module
- if err := addEnvModule(&server, runMode); err != nil {
- panic(err)
- }
-
go func() {
var m metrics.Server
for {
diff --git a/lib/gat/modules/net_listener/module.go b/lib/gat/modules/net_listener/module.go
index 17dd194ceb71e4d679019f75055a840e5c5a97f9..27ab19dc6c8e4fd97fe8950931375b514a5f9080 100644
--- a/lib/gat/modules/net_listener/module.go
+++ b/lib/gat/modules/net_listener/module.go
@@ -40,7 +40,10 @@ func (T *Module) Start() error {
}
func (T *Module) Stop() error {
- return T.listener.Close()
+ if err := T.listener.Close(); err != nil {
+ return err
+ }
+ return nil
}
func (T *Module) Addr() net.Addr {
@@ -79,7 +82,7 @@ func (T *Module) acceptLoop() {
continue
}
- T.accept(conn)
+ go T.accept(conn)
}
}
diff --git a/lib/gat/server.go b/lib/gat/server.go
index bfd9536226b8924deba44205c62db5b7aee383ef..99c143b11a97455138b25da0b41e9b1a862a2d63 100644
--- a/lib/gat/server.go
+++ b/lib/gat/server.go
@@ -20,23 +20,31 @@ type Server struct {
starters []Starter
stoppers []Stopper
+ done chan struct{}
+
keys maps.RWLocked[[8]byte, *Pool]
}
-func (T *Server) AddModule(module Module) {
- T.modules = append(T.modules, module)
- if provider, ok := module.(Provider); ok {
- T.providers = append(T.providers, provider)
- }
- if listener, ok := module.(Listener); ok {
- T.listeners = append(T.listeners, listener)
- }
- if starter, ok := module.(Starter); ok {
- T.starters = append(T.starters, starter)
- }
- if stopper, ok := module.(Stopper); ok {
- T.stoppers = append(T.stoppers, stopper)
+func NewServer(modules ...Module) *Server {
+ server := new(Server)
+
+ for _, module := range modules {
+ server.modules = append(server.modules, module)
+ if provider, ok := module.(Provider); ok {
+ server.providers = append(server.providers, provider)
+ }
+ if listener, ok := module.(Listener); ok {
+ server.listeners = append(server.listeners, listener)
+ }
+ if starter, ok := module.(Starter); ok {
+ server.starters = append(server.starters, starter)
+ }
+ if stopper, ok := module.(Stopper); ok {
+ server.stoppers = append(server.stoppers, stopper)
+ }
}
+
+ return server
}
func (T *Server) cancel(key [8]byte) error {
@@ -110,28 +118,32 @@ func (T *Server) Start() error {
}
}
+ T.done = make(chan struct{})
+
+ go T.acceptLoop()
+
+ return nil
+}
+
+func (T *Server) acceptLoop() {
var accept []<-chan AcceptedConn
for _, listener := range T.listeners {
accept = append(accept, listener.Accept()...)
}
- go func() {
- acceptor := chans.NewMultiRecv(accept)
- for {
- accepted, ok := acceptor.Recv()
- if !ok {
- break
- }
- go func() {
- if err := T.serve(accepted.Conn, accepted.Params); err != nil && !errors.Is(err, io.EOF) {
- log.Printf("failed to serve client: %v", err)
- }
- }()
+ acceptor := chans.NewMultiRecv(accept, T.done)
+ for {
+ accepted, ok := acceptor.Recv()
+ if !ok {
+ break
}
- }()
-
- return nil
+ go func() {
+ if err := T.serve(accepted.Conn, accepted.Params); err != nil && !errors.Is(err, io.EOF) {
+ log.Printf("failed to serve client: %v", err)
+ }
+ }()
+ }
}
func (T *Server) Stop() error {
@@ -142,5 +154,7 @@ func (T *Server) Stop() error {
}
}
+ close(T.done)
+
return err
}
diff --git a/lib/util/chans/multi.go b/lib/util/chans/multi.go
index 6182115d442aebd635f72c061a4bc8434f330753..2f054b80e52384b2ec31a1e77b0fa40cc3d95c8b 100644
--- a/lib/util/chans/multi.go
+++ b/lib/util/chans/multi.go
@@ -10,8 +10,12 @@ type MultiRecv[T any] struct {
cases []reflect.SelectCase
}
-func NewMultiRecv[T any](cases []<-chan T) *MultiRecv[T] {
- c := make([]reflect.SelectCase, 0, len(cases))
+func NewMultiRecv[T any](cases []<-chan T, done <-chan struct{}) *MultiRecv[T] {
+ c := make([]reflect.SelectCase, 0, len(cases)+1)
+ c = append(c, reflect.SelectCase{
+ Dir: reflect.SelectRecv,
+ Chan: reflect.ValueOf(done),
+ })
for _, ch := range cases {
c = append(c, reflect.SelectCase{
Dir: reflect.SelectRecv,
@@ -25,12 +29,12 @@ func NewMultiRecv[T any](cases []<-chan T) *MultiRecv[T] {
func (c *MultiRecv[T]) Recv() (T, bool) {
for {
- if len(c.cases) == 0 {
- return *new(T), false
- }
-
idx, value, ok := reflect.Select(c.cases)
if !ok {
+ if idx == 0 {
+ // done triggered
+ return *new(T), false
+ }
c.cases = slices.DeleteIndex(c.cases, idx)
continue
}
diff --git a/test/tester_test.go b/test/tester_test.go
index 0bf1ff4eeb61c254a40dc8a325f35aef106885cf..20f239779303d618507d8d31f584966f26688aec 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -25,8 +25,6 @@ import (
func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Dialer, error) {
for i := 0; i < n; i++ {
- var server gat.Server
-
var options = pool.Options{
Credentials: creds,
}
@@ -44,7 +42,6 @@ func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Di
m := new(raw_pools.Module)
m.Add("runner", "pool", p)
- server.AddModule(m)
l := &net_listener.Module{
Config: net_listener.Config{
@@ -56,7 +53,8 @@ func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Di
return recipe.Dialer{}, err
}
port := l.Addr().(*net.TCPAddr).Port
- server.AddModule(l)
+
+ server := gat.NewServer(m, l)
if err := server.Start(); err != nil {
panic(err)
@@ -109,8 +107,6 @@ func TestTester(t *testing.T) {
return
}
- var server gat.Server
-
m := new(raw_pools.Module)
transactionPool := pool.NewPool(transaction.Apply(pool.Options{
Credentials: creds,
@@ -129,8 +125,6 @@ func TestTester(t *testing.T) {
}))
m.Add("runner", "session", sessionPool)
- server.AddModule(m)
-
l := &net_listener.Module{
Config: net_listener.Config{
Network: "tcp",
@@ -143,7 +137,7 @@ func TestTester(t *testing.T) {
}
port := l.Addr().(*net.TCPAddr).Port
- server.AddModule(l)
+ server := gat.NewServer(m, l)
if err = server.Start(); err != nil {
t.Error(err)