diff --git a/README.md b/README.md index 6437612f51515a0654ea38357606d10656b2a2f2..071474ed2211dca9504b2f534c2cf91d828a2a88 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ i'll lyk when its done | **Feature** | **Status** | Gat Status | **Comments** | |--------------------------------|-----------------------------|--------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------| | Transaction pooling | :white_check_mark: | :white_check_mark: | Identical to PgBouncer. | -| Session pooling | :white_check_mark: | no (do we want?) | Identical to PgBouncer. | +| Session pooling | :white_check_mark: | :white_check_mark: | Identical to PgBouncer. | | `COPY` support | :white_check_mark: | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. | | Query cancellation | :white_check_mark: | :white_check_mark: | Supported both in transaction and session pooling modes. | | Load balancing of read queries | :white_check_mark: | :white_check_mark: | Using random between replicas. Primary is included when `primary_reads_enabled` is enabled (default). | diff --git a/lib/gat/admin/admin.go b/lib/gat/admin/admin.go index 3b6af16110fa3e8e766e239b55cc96cae3ba16f5..dd91b0c1a7c07c4526477a5a1cc24ba07bbb396e 100644 --- a/lib/gat/admin/admin.go +++ b/lib/gat/admin/admin.go @@ -10,7 +10,7 @@ import ( "time" ) -// The admin database, implemented through the gat.Pool interface, allowing it to be added to any existing Gat +// The admin database, implemented through the gat.Database interface, allowing it to be added to any existing Gat import ( "context" @@ -74,19 +74,19 @@ func getAdminUser(g gat.Gat) *config.User { } } -type Pool struct { +type Database struct { gat gat.Gat - connPool *ConnectionPool + connPool *Pool r cmux.Mux[gat.Client, error] } -func NewPool(g gat.Gat) *Pool { - out := &Pool{ +func New(g gat.Gat) *Database { + out := &Database{ gat: g, } - out.connPool = &ConnectionPool{ - pool: out, + out.connPool = &Pool{ + database: out, } out.r = cmux.NewMapMux[gat.Client, error]() out.r.Register([]string{"show", "stats_totals"}, func(client gat.Client, _ []string) error { @@ -179,7 +179,7 @@ func NewPool(g gat.Gat) *Pool { return out } -func (p *Pool) showStats(client gat.Client, totals, averages bool) error { +func (p *Database) showStats(client gat.Client, totals, averages bool) error { rowDesc := new(protocol.RowDescription) rowDesc.Fields.Fields = []protocol.FieldsRowDescriptionFields{ { @@ -283,7 +283,7 @@ func (p *Pool) showStats(client gat.Client, totals, averages bool) error { if err != nil { return err } - for name, pl := range p.gat.GetPools() { + for name, pl := range p.gat.GetDatabases() { stats := pl.GetStats() if stats == nil { continue @@ -350,7 +350,7 @@ func (p *Pool) showStats(client gat.Client, totals, averages bool) error { return nil } -func (p *Pool) showTotals(client gat.Client) error { +func (p *Database) showTotals(client gat.Client) error { rowDesc := new(protocol.RowDescription) rowDesc.Fields.Fields = []protocol.FieldsRowDescriptionFields{ { @@ -446,7 +446,7 @@ func (p *Pool) showTotals(client gat.Client) error { var totalXactCount, totalQueryCount, totalWaitCount, totalReceived, totalSent, totalXactTime, totalQueryTime, totalWaitTime int64 var alive time.Duration - for _, pl := range p.gat.GetPools() { + for _, pl := range p.gat.GetDatabases() { stats := pl.GetStats() if stats == nil { continue @@ -523,7 +523,7 @@ func (p *Pool) showTotals(client gat.Client) error { return client.Send(row) } -func (p *Pool) GetUser(name string) *config.User { +func (p *Database) GetUser(name string) *config.User { u := getAdminUser(p.gat) if name != u.Name { return nil @@ -531,11 +531,11 @@ func (p *Pool) GetUser(name string) *config.User { return u } -func (p *Pool) GetRouter() gat.QueryRouter { +func (p *Database) GetRouter() gat.QueryRouter { return nil } -func (p *Pool) WithUser(name string) gat.ConnectionPool { +func (p *Database) WithUser(name string) gat.Pool { conf := p.gat.GetConfig() if name != conf.General.AdminUsername { return nil @@ -543,56 +543,53 @@ func (p *Pool) WithUser(name string) gat.ConnectionPool { return p.connPool } -func (p *Pool) ConnectionPools() []gat.ConnectionPool { - return []gat.ConnectionPool{ +func (p *Database) GetPools() []gat.Pool { + return []gat.Pool{ p.connPool, } } -func (p *Pool) GetStats() *gat.PoolStats { +func (p *Database) GetStats() *gat.PoolStats { return nil } -func (p *Pool) EnsureConfig(c *config.Pool) { +func (p *Database) EnsureConfig(c *config.Pool) { // TODO } -var _ gat.Pool = (*Pool)(nil) - -type ConnectionPool struct { - pool *Pool -} +var _ gat.Database = (*Database)(nil) -func (c *ConnectionPool) GetUser() *config.User { - return getAdminUser(c.pool.gat) +type Pool struct { + database *Database } -func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus { - return getServerInfo(c.pool.gat) +func (c *Pool) GetUser() *config.User { + return getAdminUser(c.database.gat) } -func (c *ConnectionPool) GetPool() gat.Pool { - return c.pool +func (c *Pool) GetServerInfo() []*protocol.ParameterStatus { + return getServerInfo(c.database.gat) } -func (c *ConnectionPool) GetShards() []gat.Shard { - // this db is within gat, there are no shards - return nil +func (c *Pool) GetDatabase() gat.Database { + return c.database } -func (c *ConnectionPool) EnsureConfig(conf *config.Pool) { +func (c *Pool) EnsureConfig(conf *config.Pool) { // TODO } -func (c *ConnectionPool) Describe(ctx context.Context, client gat.Client, describe *protocol.Describe) error { +func (c *Pool) OnDisconnect(_ gat.Client) {} + +func (c *Pool) Describe(ctx context.Context, client gat.Client, describe *protocol.Describe) error { return errors.New("describe not implemented") } -func (c *ConnectionPool) Execute(ctx context.Context, client gat.Client, execute *protocol.Execute) error { +func (c *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol.Execute) error { return errors.New("execute not implemented") } -func (c *ConnectionPool) SimpleQuery(ctx context.Context, client gat.Client, query string) error { +func (c *Pool) SimpleQuery(ctx context.Context, client gat.Client, query string) error { parsed, err := parse.Parse(query) if err != nil { return err @@ -602,7 +599,7 @@ func (c *ConnectionPool) SimpleQuery(ctx context.Context, client gat.Client, que } for _, cmd := range parsed { var matched bool - err, matched = c.pool.r.Call(client, append([]string{cmd.Command}, cmd.Arguments...)) + err, matched = c.database.r.Call(client, append([]string{cmd.Command}, cmd.Arguments...)) if !matched { return errors.New("unknown command") } @@ -619,12 +616,12 @@ func (c *ConnectionPool) SimpleQuery(ctx context.Context, client gat.Client, que return nil } -func (c *ConnectionPool) Transaction(ctx context.Context, client gat.Client, query string) error { +func (c *Pool) Transaction(ctx context.Context, client gat.Client, query string) error { return errors.New("transactions not implemented") } -func (c *ConnectionPool) CallFunction(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { +func (c *Pool) CallFunction(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { return errors.New("functions not implemented") } -var _ gat.ConnectionPool = (*ConnectionPool)(nil) +var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/gatling/pool/pool.go b/lib/gat/database/database.go similarity index 57% rename from lib/gat/gatling/pool/pool.go rename to lib/gat/database/database.go index 29e4932177907e8e9430e331a02c29ab75a53b1d..be9de25ae9278d64fe9a23ba7b0fe44050399e07 100644 --- a/lib/gat/gatling/pool/pool.go +++ b/lib/gat/database/database.go @@ -1,37 +1,41 @@ -package pool +package database import ( - "gfx.cafe/gfx/pggat/lib/gat/gatling/pool/conn_pool" - "gfx.cafe/gfx/pggat/lib/gat/gatling/pool/query_router" + "gfx.cafe/gfx/pggat/lib/gat/database/query_router" + "gfx.cafe/gfx/pggat/lib/gat/pool/session" "sync" "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat" ) -type Pool struct { +type Database struct { c *config.Pool users map[string]config.User - connPools map[string]gat.ConnectionPool + connPools map[string]gat.Pool stats *gat.PoolStats router *query_router.QueryRouter + dialer gat.Dialer + mu sync.RWMutex } -func NewPool(conf *config.Pool) *Pool { - pool := &Pool{ - connPools: make(map[string]gat.ConnectionPool), +func New(dialer gat.Dialer, conf *config.Pool) *Database { + pool := &Database{ + connPools: make(map[string]gat.Pool), stats: gat.NewPoolStats(), router: query_router.DefaultRouter, + + dialer: dialer, } pool.EnsureConfig(conf) return pool } -func (p *Pool) EnsureConfig(conf *config.Pool) { +func (p *Database) EnsureConfig(conf *config.Pool) { p.mu.Lock() defer p.mu.Unlock() p.c = conf @@ -45,12 +49,12 @@ func (p *Pool) EnsureConfig(conf *config.Pool) { existing.EnsureConfig(conf) } else { u := user - p.connPools[name] = conn_pool.NewConnectionPool(p, conf, &u) + p.connPools[name] = session.New(p, p.dialer, conf, &u) } } } -func (p *Pool) GetUser(name string) *config.User { +func (p *Database) GetUser(name string) *config.User { p.mu.RLock() defer p.mu.RUnlock() user, ok := p.users[name] @@ -60,11 +64,11 @@ func (p *Pool) GetUser(name string) *config.User { return &user } -func (p *Pool) GetRouter() gat.QueryRouter { +func (p *Database) GetRouter() gat.QueryRouter { return p.router } -func (p *Pool) WithUser(name string) gat.ConnectionPool { +func (p *Database) WithUser(name string) gat.Pool { p.mu.RLock() defer p.mu.RUnlock() pool, ok := p.connPools[name] @@ -74,10 +78,10 @@ func (p *Pool) WithUser(name string) gat.ConnectionPool { return pool } -func (p *Pool) ConnectionPools() []gat.ConnectionPool { +func (p *Database) GetPools() []gat.Pool { p.mu.RLock() defer p.mu.RUnlock() - out := make([]gat.ConnectionPool, len(p.connPools)) + out := make([]gat.Pool, len(p.connPools)) idx := 0 for _, v := range p.connPools { out[idx] = v @@ -86,8 +90,8 @@ func (p *Pool) ConnectionPools() []gat.ConnectionPool { return out } -func (p *Pool) GetStats() *gat.PoolStats { +func (p *Database) GetStats() *gat.PoolStats { return p.stats } -var _ gat.Pool = (*Pool)(nil) +var _ gat.Database = (*Database)(nil) diff --git a/lib/gat/gatling/pool/pool_test.go b/lib/gat/database/database_test.go similarity index 72% rename from lib/gat/gatling/pool/pool_test.go rename to lib/gat/database/database_test.go index 9fbebf403745c0851235fc49cd6ff7b369f1d93f..e2f1b03a90cb076a0edd2b830b737cc0baa1655e 100644 --- a/lib/gat/gatling/pool/pool_test.go +++ b/lib/gat/database/database_test.go @@ -1,3 +1,3 @@ -package pool +package database // TODO: no tests, we need to write our own diff --git a/lib/gat/gatling/pool/query_router/query_router.go b/lib/gat/database/query_router/query_router.go similarity index 100% rename from lib/gat/gatling/pool/query_router/query_router.go rename to lib/gat/database/query_router/query_router.go diff --git a/lib/gat/gatling/pool/query_router/query_router_test.go b/lib/gat/database/query_router/query_router_test.go similarity index 100% rename from lib/gat/gatling/pool/query_router/query_router_test.go rename to lib/gat/database/query_router/query_router_test.go diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 727df9c1c43dd7c283d584d6ecd3353cb5228c47..79ab36d12f689e13565955f365e4f8bfbddc966e 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -76,7 +76,7 @@ type Client struct { requestTime time.Time connectTime time.Time - server gat.ConnectionPool + server gat.Pool poolName string username string @@ -128,7 +128,7 @@ func (c *Client) GetRemotePid() int { return int(c.pid) } -func (c *Client) GetConnectionPool() gat.ConnectionPool { +func (c *Client) GetConnectionPool() gat.Pool { return c.server } @@ -181,7 +181,7 @@ func NewClient( return c } -func (c *Client) Id() gat.ClientID { +func (c *Client) GetId() gat.ClientID { return gat.ClientID{ PID: c.pid, SecretKey: c.secretKey, @@ -325,7 +325,7 @@ func (c *Client) Accept(ctx context.Context) error { } } - pool := c.gatling.GetPool(c.poolName) + pool := c.gatling.GetDatabase(c.poolName) if pool == nil { return fmt.Errorf("pool '%s' not found", c.poolName) } @@ -362,6 +362,7 @@ func (c *Client) Accept(ctx context.Context) error { if c.server == nil { return fmt.Errorf("no pool for '%s'", c.username) } + defer c.server.OnDisconnect(c) authOk := new(protocol.Authentication) authOk.Fields.Code = 0 @@ -400,7 +401,7 @@ func (c *Client) Accept(ctx context.Context) error { } open, err = c.tick(ctx) // add send and recv to pool - stats := c.server.GetPool().GetStats() + stats := c.server.GetDatabase().GetStats() if stats != nil { stats.AddTotalSent(c.wr.BytesWritten.Swap(0)) stats.AddTotalReceived(c.r.BytesRead.Swap(0)) diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index 115276f0a912474df8aec4c2a56f1f9dd3852bd2..5187f13cc2d22daa1dd058ebbb637ee5da8de7f8 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "gfx.cafe/gfx/pggat/lib/gat/admin" + "gfx.cafe/gfx/pggat/lib/gat/database" + "gfx.cafe/gfx/pggat/lib/gat/gatling/server" "io" "net" "sync" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/gatling/client" - "gfx.cafe/gfx/pggat/lib/gat/gatling/pool" "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" "git.tuxpa.in/a/zlog/log" @@ -26,18 +27,18 @@ type Gatling struct { // channel that new config are delivered chConfig chan *config.Global - pools map[string]gat.Pool + pools map[string]gat.Database clients map[gat.ClientID]gat.Client } func NewGatling(conf *config.Global) *Gatling { g := &Gatling{ chConfig: make(chan *config.Global, 1), - pools: make(map[string]gat.Pool), + pools: make(map[string]gat.Database), clients: make(map[gat.ClientID]gat.Client), } // add admin pool - adminPool := admin.NewPool(g) + adminPool := admin.New(g) g.pools["pgbouncer"] = adminPool g.pools["pggat"] = adminPool @@ -67,7 +68,7 @@ func (g *Gatling) GetConfig() *config.Global { return g.c } -func (g *Gatling) GetPool(name string) gat.Pool { +func (g *Gatling) GetDatabase(name string) gat.Database { g.mu.RLock() defer g.mu.RUnlock() srv, ok := g.pools[name] @@ -77,7 +78,7 @@ func (g *Gatling) GetPool(name string) gat.Pool { return srv } -func (g *Gatling) GetPools() map[string]gat.Pool { +func (g *Gatling) GetDatabases() map[string]gat.Database { g.mu.RLock() defer g.mu.RUnlock() return g.pools @@ -139,7 +140,7 @@ func (g *Gatling) ensurePools(c *config.Global) error { if existing, ok := g.pools[name]; ok { existing.EnsureConfig(p) } else { - g.pools[name] = pool.NewPool(p) + g.pools[name] = database.New(server.Dial, p) } } return nil @@ -178,12 +179,12 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { func() { g.mu.Lock() defer g.mu.Unlock() - g.clients[cl.Id()] = cl + g.clients[cl.GetId()] = cl }() defer func() { g.mu.Lock() defer g.mu.Unlock() - delete(g.clients, cl.Id()) + delete(g.clients, cl.GetId()) }() err := cl.Accept(ctx) diff --git a/lib/gat/gatling/pool/conn_pool/conn_pool.go b/lib/gat/gatling/pool/conn_pool/conn_pool.go deleted file mode 100644 index 3a551e987c8bd2e7bfc375057d05b78fbcac1f70..0000000000000000000000000000000000000000 --- a/lib/gat/gatling/pool/conn_pool/conn_pool.go +++ /dev/null @@ -1,102 +0,0 @@ -package conn_pool - -import ( - "context" - "gfx.cafe/gfx/pggat/lib/config" - "gfx.cafe/gfx/pggat/lib/gat" - "gfx.cafe/gfx/pggat/lib/gat/protocol" - "runtime" - "sync/atomic" - "time" -) - -type ConnectionPool struct { - // the pool connection - c atomic.Pointer[config.Pool] - user *config.User - pool gat.Pool - workerCount atomic.Int64 - - // see: https://github.com/golang/go/blob/master/src/runtime/chan.go#L33 - // channels are a thread safe ring buffer implemented via a linked list of goroutines. - // the idea is that goroutines are cheap, and we can afford to have one per pending request. - // there is no real reason to implement a complicated worker pool pattern when well, if we're okay with having a 2-4kb overhead per request, then this is fine. trading space for code complexity - workerPool chan *worker -} - -func NewConnectionPool(pool gat.Pool, conf *config.Pool, user *config.User) *ConnectionPool { - p := &ConnectionPool{ - user: user, - pool: pool, - workerPool: make(chan *worker, 1+runtime.NumCPU()*4), - } - p.EnsureConfig(conf) - return p -} - -func (c *ConnectionPool) GetPool() gat.Pool { - return c.pool -} - -func (c *ConnectionPool) getWorker() *worker { - start := time.Now() - defer func() { - c.pool.GetStats().AddWaitTime(time.Now().Sub(start).Microseconds()) - }() - select { - case w := <-c.workerPool: - return w - default: - if c.workerCount.Add(1)-1 < int64(c.user.PoolSize) { - next := &worker{ - w: c, - } - return next - } else { - w := <-c.workerPool - return w - } - } -} - -func (c *ConnectionPool) EnsureConfig(conf *config.Pool) { - c.c.Store(conf) -} - -func (c *ConnectionPool) GetUser() *config.User { - return c.user -} - -func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus { - return c.getWorker().GetServerInfo() -} - -func (c *ConnectionPool) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { - return c.getWorker().HandleDescribe(ctx, client, d) -} - -func (c *ConnectionPool) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { - return c.getWorker().HandleExecute(ctx, client, e) -} - -func (c *ConnectionPool) SimpleQuery(ctx context.Context, client gat.Client, q string) error { - // see if the pool router can handle it - handled, err := c.pool.GetRouter().TryHandle(client, q) - if err != nil { - return err - } - if handled { - return nil - } - return c.getWorker().HandleSimpleQuery(ctx, client, q) -} - -func (c *ConnectionPool) Transaction(ctx context.Context, client gat.Client, q string) error { - return c.getWorker().HandleTransaction(ctx, client, q) -} - -func (c *ConnectionPool) CallFunction(ctx context.Context, client gat.Client, f *protocol.FunctionCall) error { - return c.getWorker().HandleFunction(ctx, client, f) -} - -var _ gat.ConnectionPool = (*ConnectionPool)(nil) diff --git a/lib/gat/gatling/pool/conn_pool/shard/server/server.go b/lib/gat/gatling/server/server.go similarity index 97% rename from lib/gat/gatling/pool/conn_pool/shard/server/server.go rename to lib/gat/gatling/server/server.go index c2b77f1a89624f80b97d9a7b37ffc2a8c8b03ac8..fc0813d3c05c83ce1a43a89808c9e45fe0f0cca7 100644 --- a/lib/gat/gatling/pool/conn_pool/shard/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -57,38 +57,33 @@ type Server struct { mu sync.Mutex } -func Dial(ctx context.Context, - addr string, - port uint16, - user *config.User, - db string, dbuser string, dbpass string, -) (*Server, error) { +func Dial(ctx context.Context, user *config.User, shard *config.Shard, server *config.Server) (gat.Connection, error) { s := &Server{ - addr: addr, - port: port, + addr: server.Host, + port: server.Port, state: gat.ConnectionNew, boundPreparedStatments: make(map[string]*protocol.Parse), boundPortals: make(map[string]*protocol.Bind), - dbuser: dbuser, - dbpass: dbpass, + dbuser: server.Username, + dbpass: server.Password, } var err error - s.conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + s.conn, err = net.Dial("tcp", fmt.Sprintf("%s:%d", server.Host, server.Port)) if err != nil { return nil, err } s.r = bufio.NewReader(s.conn) s.wr = bufio.NewWriter(s.conn) s.user = *user - s.db = db + s.db = shard.Database s.log = log.With(). Stringer("addr", s.conn.RemoteAddr()). Str("user", user.Name). - Str("db", db). + Str("db", shard.Database). Logger() return s, s.connect(ctx) } diff --git a/lib/gat/gatling/pool/conn_pool/shard/server/server_test.go b/lib/gat/gatling/server/server_test.go similarity index 100% rename from lib/gat/gatling/pool/conn_pool/shard/server/server_test.go rename to lib/gat/gatling/server/server_test.go diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index 35a3ec7d0a66e2c62d597e9b6629ed043463687d..6ded9dce9cb926690b9a963b3b1b037b57bcc5c0 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -21,12 +21,14 @@ const ( ) type Client interface { + GetId() ClientID + GetPreparedStatement(name string) *protocol.Parse GetPortal(name string) *protocol.Bind GetCurrentConn() Connection SetCurrentConn(conn Connection) - GetConnectionPool() ConnectionPool + GetConnectionPool() Pool GetState() ClientState GetAddress() net.Addr @@ -50,18 +52,18 @@ type Client interface { type Gat interface { GetVersion() string GetConfig() *config.Global - GetPool(name string) Pool - GetPools() map[string]Pool + GetDatabase(name string) Database + GetDatabases() map[string]Database GetClient(id ClientID) Client GetClients() []Client } -type Pool interface { +type Database interface { GetUser(name string) *config.User GetRouter() QueryRouter - WithUser(name string) ConnectionPool - ConnectionPools() []ConnectionPool + WithUser(name string) Pool + GetPools() []Pool GetStats() *PoolStats @@ -74,14 +76,16 @@ type QueryRouter interface { TryHandle(client Client, query string) (bool, error) } -type ConnectionPool interface { +type Pool interface { GetUser() *config.User GetServerInfo() []*protocol.ParameterStatus - GetPool() Pool + GetDatabase() Database EnsureConfig(c *config.Pool) + OnDisconnect(client Client) + // extended queries Describe(ctx context.Context, client Client, describe *protocol.Describe) error Execute(ctx context.Context, client Client, execute *protocol.Execute) error @@ -92,14 +96,6 @@ type ConnectionPool interface { CallFunction(ctx context.Context, client Client, payload *protocol.FunctionCall) error } -type Shard interface { - GetPrimary() Connection - GetReplicas() []Connection - Choose(role config.ServerRole) Connection - - EnsureConfig(c *config.Shard) -} - type ConnectionState string const ( @@ -110,6 +106,8 @@ const ( ConnectionNew = "new" ) +type Dialer = func(context.Context, *config.User, *config.Shard, *config.Server) (Connection, error) + type Connection interface { GetServerInfo() []*protocol.ParameterStatus diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..abf652b12d09c335524a23041e9db1f3caa02c36 --- /dev/null +++ b/lib/gat/pool/session/pool.go @@ -0,0 +1,113 @@ +package session + +import ( + "context" + "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/protocol" + "runtime" + "sync/atomic" +) + +type Pool struct { + c atomic.Pointer[config.Pool] + user *config.User + database gat.Database + + dialer gat.Dialer + + assigned map[gat.ClientID]gat.Connection + + servers chan gat.Connection +} + +func New(database gat.Database, dialer gat.Dialer, conf *config.Pool, user *config.User) *Pool { + p := &Pool{ + user: user, + database: database, + + dialer: dialer, + + assigned: make(map[gat.ClientID]gat.Connection), + + servers: make(chan gat.Connection, 1+runtime.NumCPU()*4), + } + p.EnsureConfig(conf) + return p +} + +func (p *Pool) getConnection() gat.Connection { + select { + case c := <-p.servers: + return c + default: + shard := p.c.Load().Shards[0] + s, _ := p.dialer(context.TODO(), p.user, shard, shard.Servers[0]) + return s + } +} + +func (p *Pool) returnConnection(c gat.Connection) { + p.servers <- c +} + +func (p *Pool) getOrAssign(client gat.Client) gat.Connection { + cid := client.GetId() + c, ok := p.assigned[cid] + if !ok { + get := p.getConnection() + p.assigned[cid] = get + return get + } + return c +} + +func (p *Pool) GetDatabase() gat.Database { + return p.database +} + +func (p *Pool) EnsureConfig(c *config.Pool) { + p.c.Store(c) +} + +func (p *Pool) OnDisconnect(client gat.Client) { + cid := client.GetId() + c, ok := p.assigned[cid] + if !ok { + return + } + delete(p.assigned, cid) + p.servers <- c +} + +func (p *Pool) GetUser() *config.User { + return p.user +} + +func (p *Pool) GetServerInfo() []*protocol.ParameterStatus { + c := p.getConnection() + defer p.returnConnection(c) + return c.GetServerInfo() +} + +func (p *Pool) Describe(ctx context.Context, client gat.Client, describe *protocol.Describe) error { + return p.getOrAssign(client).Describe(client, describe) +} + +func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol.Execute) error { + return p.getOrAssign(client).Execute(client, execute) +} + +func (p *Pool) SimpleQuery(ctx context.Context, client gat.Client, query string) error { + return p.getOrAssign(client).SimpleQuery(ctx, client, query) +} + +func (p *Pool) Transaction(ctx context.Context, client gat.Client, query string) error { + return p.getOrAssign(client).Transaction(ctx, client, query) +} + +func (p *Pool) CallFunction(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { + return p.getOrAssign(client).CallFunction(client, payload) +} + +var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/pool/transaction/pool.go b/lib/gat/pool/transaction/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..ca14ae306716588fb04d9afd9098297e17b3af77 --- /dev/null +++ b/lib/gat/pool/transaction/pool.go @@ -0,0 +1,107 @@ +package transaction + +import ( + "context" + "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/protocol" + "runtime" + "sync/atomic" + "time" +) + +type Pool struct { + // the database connection + c atomic.Pointer[config.Pool] + user *config.User + database gat.Database + workerCount atomic.Int64 + + dialer gat.Dialer + + // see: https://github.com/golang/go/blob/master/src/runtime/chan.go#L33 + // channels are a thread safe ring buffer implemented via a linked list of goroutines. + // the idea is that goroutines are cheap, and we can afford to have one per pending request. + // there is no real reason to implement a complicated worker database pattern when well, if we're okay with having a 2-4kb overhead per request, then this is fine. trading space for code complexity + workerPool chan *worker +} + +func New(database gat.Database, dialer gat.Dialer, conf *config.Pool, user *config.User) *Pool { + p := &Pool{ + user: user, + database: database, + dialer: dialer, + workerPool: make(chan *worker, 1+runtime.NumCPU()*4), + } + p.EnsureConfig(conf) + return p +} + +func (c *Pool) GetDatabase() gat.Database { + return c.database +} + +func (c *Pool) getWorker() *worker { + start := time.Now() + defer func() { + c.database.GetStats().AddWaitTime(time.Now().Sub(start).Microseconds()) + }() + select { + case w := <-c.workerPool: + return w + default: + if c.workerCount.Add(1)-1 < int64(c.user.PoolSize) { + next := &worker{ + w: c, + } + return next + } else { + w := <-c.workerPool + return w + } + } +} + +func (c *Pool) EnsureConfig(conf *config.Pool) { + c.c.Store(conf) +} + +func (c *Pool) OnDisconnect(_ gat.Client) {} + +func (c *Pool) GetUser() *config.User { + return c.user +} + +func (c *Pool) GetServerInfo() []*protocol.ParameterStatus { + return c.getWorker().GetServerInfo() +} + +func (c *Pool) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { + return c.getWorker().HandleDescribe(ctx, client, d) +} + +func (c *Pool) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { + return c.getWorker().HandleExecute(ctx, client, e) +} + +func (c *Pool) SimpleQuery(ctx context.Context, client gat.Client, q string) error { + // see if the database router can handle it + handled, err := c.database.GetRouter().TryHandle(client, q) + if err != nil { + return err + } + if handled { + return nil + } + return c.getWorker().HandleSimpleQuery(ctx, client, q) +} + +func (c *Pool) Transaction(ctx context.Context, client gat.Client, q string) error { + return c.getWorker().HandleTransaction(ctx, client, q) +} + +func (c *Pool) CallFunction(ctx context.Context, client gat.Client, f *protocol.FunctionCall) error { + return c.getWorker().HandleFunction(ctx, client, f) +} + +var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/gatling/pool/conn_pool/shard/shard.go b/lib/gat/pool/transaction/shard/shard.go similarity index 80% rename from lib/gat/gatling/pool/conn_pool/shard/shard.go rename to lib/gat/pool/transaction/shard/shard.go index a3bb64ad60fe467efe3c8057c94a3978cbee270b..05f440a202982a7d7bd93af0a6eec2d0fff045ec 100644 --- a/lib/gat/gatling/pool/conn_pool/shard/shard.go +++ b/lib/gat/pool/transaction/shard/shard.go @@ -4,7 +4,6 @@ import ( "context" "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat" - "gfx.cafe/gfx/pggat/lib/gat/gatling/pool/conn_pool/shard/server" "math/rand" "reflect" ) @@ -15,12 +14,16 @@ type Shard struct { user *config.User conf *config.Shard + + dialer gat.Dialer } -func FromConfig(user *config.User, conf *config.Shard) *Shard { +func FromConfig(dialer gat.Dialer, user *config.User, conf *config.Shard) *Shard { out := &Shard{ user: user, conf: conf, + + dialer: dialer, } out.init() return out @@ -30,7 +33,7 @@ func (s *Shard) init() { s.primary = nil s.replicas = nil for _, serv := range s.conf.Servers { - srv, err := server.Dial(context.TODO(), serv.Host, serv.Port, s.user, s.conf.Database, serv.Username, serv.Password) + srv, err := s.dialer(context.TODO(), s.user, s.conf, serv) if err != nil { continue } @@ -72,5 +75,3 @@ func (s *Shard) EnsureConfig(c *config.Shard) { s.init() } } - -var _ gat.Shard = (*Shard)(nil) diff --git a/lib/gat/gatling/pool/conn_pool/worker.go b/lib/gat/pool/transaction/worker.go similarity index 90% rename from lib/gat/gatling/pool/conn_pool/worker.go rename to lib/gat/pool/transaction/worker.go index f20a3814561419f6f2f8011d26f9ed81b38400fa..92e5ae84abbb78fcc5f99228e485154555d9aeff 100644 --- a/lib/gat/gatling/pool/conn_pool/worker.go +++ b/lib/gat/pool/transaction/worker.go @@ -1,11 +1,11 @@ -package conn_pool +package transaction import ( "context" "fmt" "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat" - "gfx.cafe/gfx/pggat/lib/gat/gatling/pool/conn_pool/shard" + "gfx.cafe/gfx/pggat/lib/gat/pool/transaction/shard" "gfx.cafe/gfx/pggat/lib/gat/protocol" "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" "math/rand" @@ -13,19 +13,19 @@ import ( "time" ) -// a single use worker with an embedded connection pool. -// it wraps a pointer to the connection pool. +// a single use worker with an embedded connection database. +// it wraps a pointer to the connection database. type worker struct { - // the parent connectino pool - w *ConnectionPool + // the parent connectino database + w *Pool rev int - shards []gat.Shard + shards []*shard.Shard mu sync.Mutex } -// ret urn worker to pool +// ret urn worker to database func (w *worker) ret() { w.w.workerPool <- w } @@ -41,7 +41,7 @@ func (w *worker) fetchShard(n int) bool { w.shards = append(w.shards, nil) } - w.shards[n] = shard.FromConfig(w.w.user, conf.Shards[n]) + w.shards[n] = shard.FromConfig(w.w.dialer, w.w.user, conf.Shards[n]) return true } @@ -52,7 +52,7 @@ func (w *worker) invalidateShard(n int) { w.shards[n] = nil } -func (w *worker) chooseShard(client gat.Client) gat.Shard { +func (w *worker) chooseShard(client gat.Client) *shard.Shard { w.mu.Lock() defer w.mu.Unlock() @@ -164,7 +164,7 @@ func (w *worker) HandleSimpleQuery(ctx context.Context, c gat.Client, query stri start := time.Now() defer func() { - w.w.pool.GetStats().AddQueryTime(time.Now().Sub(start).Microseconds()) + w.w.database.GetStats().AddQueryTime(time.Now().Sub(start).Microseconds()) }() errch := make(chan error) @@ -190,7 +190,7 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri start := time.Now() defer func() { - w.w.pool.GetStats().AddXactTime(time.Now().Sub(start).Microseconds()) + w.w.database.GetStats().AddXactTime(time.Now().Sub(start).Microseconds()) }() errch := make(chan error) @@ -262,7 +262,7 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p } } - which, err := w.w.pool.GetRouter().InferRole(ps.Fields.Query) + which, err := w.w.database.GetRouter().InferRole(ps.Fields.Query) if err != nil { return err } @@ -299,7 +299,7 @@ func (w *worker) z_actually_do_simple_query(ctx context.Context, client gat.Clie return fmt.Errorf("call to query '%s' failed", payload) } // run the query on the server - which, err := w.w.pool.GetRouter().InferRole(payload) + which, err := w.w.database.GetRouter().InferRole(payload) if err != nil { return fmt.Errorf("error parsing '%s': %w", payload, err) } @@ -324,7 +324,7 @@ func (w *worker) z_actually_do_transaction(ctx context.Context, client gat.Clien return fmt.Errorf("call to transaction '%s' failed", payload) } // run the query on the server - which, err := w.w.pool.GetRouter().InferRole(payload) + which, err := w.w.database.GetRouter().InferRole(payload) if err != nil { return fmt.Errorf("error parsing '%s': %w", payload, err) }