diff --git a/lib/gat/modules/raw_pools/module.go b/lib/gat/modules/raw_pools/module.go new file mode 100644 index 0000000000000000000000000000000000000000..2a1fe7170c1b34315135fb150fda6d1666b364ef --- /dev/null +++ b/lib/gat/modules/raw_pools/module.go @@ -0,0 +1,56 @@ +package raw_pools + +import ( + "sync" + + "pggat/lib/gat" + "pggat/lib/gat/metrics" + "pggat/lib/gat/pool" + "pggat/lib/util/maps" +) + +type Module struct { + pools maps.TwoKey[string, string, *pool.Pool] + mu sync.RWMutex +} + +func NewModule() (*Module, error) { + return &Module{}, nil +} + +func (T *Module) GatModule() {} + +func (T *Module) Add(user, database string, p *pool.Pool) { + T.mu.Lock() + defer T.mu.Unlock() + + T.pools.Store(user, database, p) +} + +func (T *Module) Remove(user, database string) { + T.mu.Lock() + defer T.mu.Unlock() + + T.pools.Delete(user, database) +} + +func (T *Module) Lookup(user, database string) *gat.Pool { + T.mu.RLock() + defer T.mu.RUnlock() + + p, _ := T.pools.Load(user, database) + return p +} + +func (T *Module) ReadMetrics(metrics *metrics.Pools) { + T.mu.RLock() + defer T.mu.RUnlock() + + T.pools.Range(func(_ string, _ string, p *pool.Pool) bool { + p.ReadMetrics(&metrics.Pool) + return true + }) +} + +var _ gat.Module = (*Module)(nil) +var _ gat.Provider = (*Module)(nil) diff --git a/lib/gat/server.go b/lib/gat/server.go index 9c18e69701316e0f8cfd0010a22dbd5d9025e8cf..8e81694ef86833a2354a2ec011b1e9870906cb57 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -115,19 +115,23 @@ func (T *Server) accept(raw net.Conn, acceptOptions FrontendAcceptOptions) { } } -func (T *Server) listenAndServe(endpoint Endpoint) error { - listener, err := net.Listen(endpoint.Network, endpoint.Address) +func (T *Server) Listen(network, address string) (net.Listener, error) { + listener, err := net.Listen(network, address) if err != nil { - return err + return nil, err } - if endpoint.Network == "unix" { + if network == "unix" { beforeexit.Run(func() { _ = listener.Close() }) } - log.Printf("listening on %s(%s)", endpoint.Network, endpoint.Address) + log.Printf("listening on %s(%s)", network, address) + + return listener, nil +} +func (T *Server) Serve(listener net.Listener, acceptOptions FrontendAcceptOptions) error { for { raw, err := listener.Accept() if err != nil { @@ -136,7 +140,7 @@ func (T *Server) listenAndServe(endpoint Endpoint) error { } } - go T.accept(raw, endpoint.AcceptOptions) + go T.accept(raw, acceptOptions) } return nil @@ -151,7 +155,11 @@ func (T *Server) ListenAndServe() error { for _, endpoint := range endpoints { e := endpoint b.Queue(func() error { - return T.listenAndServe(e) + listener, err := T.Listen(e.Network, e.Address) + if err != nil { + return err + } + return T.Serve(listener, e.AcceptOptions) }) } } diff --git a/test/tester_test.go b/test/tester_test.go index eefc26a13d65eb237fe699a2311d96b221cdf5b2..8e983833c0f2550b5ef958df6f627fde89808122 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -14,8 +14,8 @@ import ( "pggat/lib/bouncer/backends/v0" "pggat/lib/bouncer/frontends/v0" "pggat/lib/gat" + "pggat/lib/gat/modules/raw_pools" "pggat/lib/gat/pool" - "pggat/lib/gat/pool/dialer" "pggat/lib/gat/pool/pools/session" "pggat/lib/gat/pool/pools/transaction" "pggat/lib/gat/pool/recipe" @@ -23,9 +23,9 @@ import ( "pggat/test/tests" ) -func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, error) { +func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Dialer, error) { for i := 0; i < n; i++ { - var g gat.PoolsMap + var server gat.Server var options = pool.Options{ Credentials: creds, @@ -41,22 +41,28 @@ func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, p.AddRecipe("runner", recipe.NewRecipe(recipe.Options{ Dialer: control, })) - g.Add("runner", "pool", p) - listener, err := gat.Listen("tcp", ":0", frontends.AcceptOptions{}) + m, err := raw_pools.NewModule() if err != nil { - return dialer.Net{}, err + return recipe.Dialer{}, err } - port := listener.Listener.Addr().(*net.TCPAddr).Port + m.Add("runner", "pool", p) + server.AddModule(m) + + listener, err := server.Listen("tcp", ":0") + if err != nil { + return recipe.Dialer{}, err + } + port := listener.Addr().(*net.TCPAddr).Port go func() { - err := gat.Serve(listener, gat.NewKeyedPools(&g)) + err := server.Serve(listener, frontends.AcceptOptions{}) if err != nil { panic(err) } }() - control = dialer.Net{ + control = recipe.Dialer{ Network: "tcp", Address: ":" + strconv.Itoa(port), AcceptOptions: backends.AcceptOptions{ @@ -71,7 +77,7 @@ func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, } func TestTester(t *testing.T) { - control := dialer.Net{ + control := recipe.Dialer{ Network: "tcp", Address: "localhost:5432", AcceptOptions: backends.AcceptOptions{ @@ -103,15 +109,20 @@ func TestTester(t *testing.T) { return } - var g gat.PoolsMap + var server gat.Server + m, err := raw_pools.NewModule() + if err != nil { + t.Error(err) + return + } transactionPool := pool.NewPool(transaction.Apply(pool.Options{ Credentials: creds, })) transactionPool.AddRecipe("runner", recipe.NewRecipe(recipe.Options{ Dialer: parent, })) - g.Add("runner", "transaction", transactionPool) + m.Add("runner", "transaction", transactionPool) sessionPool := pool.NewPool(session.Apply(pool.Options{ Credentials: creds, @@ -120,23 +131,25 @@ func TestTester(t *testing.T) { sessionPool.AddRecipe("runner", recipe.NewRecipe(recipe.Options{ Dialer: parent, })) - g.Add("runner", "session", sessionPool) + m.Add("runner", "session", sessionPool) + + server.AddModule(m) - listener, err := gat.Listen("tcp", ":0", frontends.AcceptOptions{}) + listener, err := server.Listen("tcp", ":0") if err != nil { t.Error(err) return } - port := listener.Listener.Addr().(*net.TCPAddr).Port + port := listener.Addr().(*net.TCPAddr).Port go func() { - err := gat.Serve(listener, gat.NewKeyedPools(&g)) + err := server.Serve(listener, frontends.AcceptOptions{}) if err != nil { t.Error(err) } }() - transactionDialer := dialer.Net{ + transactionDialer := recipe.Dialer{ Network: "tcp", Address: ":" + strconv.Itoa(port), AcceptOptions: backends.AcceptOptions{ @@ -145,7 +158,7 @@ func TestTester(t *testing.T) { Database: "transaction", }, } - sessionDialer := dialer.Net{ + sessionDialer := recipe.Dialer{ Network: "tcp", Address: ":" + strconv.Itoa(port), AcceptOptions: backends.AcceptOptions{