diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 078a1b07e3a19c4a5f6a68884ea62684dd90e3ae..d73828e6fe95e158c10577c7615755170ba77754 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -2,13 +2,14 @@ package main import ( "context" + "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat/gatling" "git.tuxpa.in/a/zlog/log" ) // test config, should be changed -const CONFIG = "./config_data.toml" +const CONFIG = "./config_data.yml" func main() { conf, err := config.Load(CONFIG) @@ -16,7 +17,6 @@ func main() { panic(err) } g := gatling.NewGatling(conf) - log.Println("listening on port", conf.General.Port) err = g.ListenAndServe(context.Background()) if err != nil { diff --git a/config_data.yml b/config_data.yml new file mode 100644 index 0000000000000000000000000000000000000000..65148d671b0915f2fca30d15ca5a1ea47c80368c --- /dev/null +++ b/config_data.yml @@ -0,0 +1,96 @@ +general: + host: 0.0.0.0 + port: 6432 + enable_prometheus_exporter: true + prometheus_exporter_port: 9090 + connect_timeout: 5000 + healthcheck_timeout: 1000 + healthcheck_delay: 30000 + shutdown_timeout: 60000 + ban_time: 60 + autoreload: false + admin_username: postgres + admin_password: postgres +pools: + simple_db: + pool_mode: session + default_role: primary + query_parser_enabled: true + primary_reads_enabled: true + sharding_function: pg_bigint_hash + users: + - username: postgres + password: postgres + role: reader + pool_size: 5 + statement_timeout: 0 + - username: postgres + password: password + role: writer + pool_size: 5 + statement_timeout: 0 + shards: + - pool_size: 5 + statement_timeout: 0 + database: postgres + servers: + - host: localhost + port: 5432 + role: primary + username: postgres + password: postgres + - host: localhost + role: replica + port: 5432 + username: postgres + password: postgres + sharded: + pool_mode: transaction + default_role: any + query_parser_enabled: true + primary_reads_enabled: true + sharding_function: pg_bigint_hash + users: + - username: postgres + password: postgres + pool_size: 9 + statement_timeout: 0 + - username: postgres + password: postgres + pool_size: 21 + statement_timeout: 15000 + shards: + - database: postgres + servers: + - host: localhost + port: 5432 + role: primary + username: postgres + password: postgres + - host: localhost + role: replica + port: 5432 + username: postgres + password: postgres + - database: postgres + servers: + - host: localhost + port: 5432 + role: primary + - host: localhost + role: replica + port: 5432 + username: postgres + password: postgres + - database: postgres + servers: + - host: localhost + port: 5432 + role: primary + username: postgres + password: postgres + - host: localhost + role: replica + port: 5432 + username: postgres + password: postgres diff --git a/go.mod b/go.mod index 5b9c52983b6015b4363acd208754a4df43bc7675..feff4852ac17ab873e632cbde4440e7dc7b25402 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/xdg-go/scram v1.1.1 golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( diff --git a/go.sum b/go.sum index a66211542b6ffec50500ead79b713358089d615f..a23ece5ff68027e7542bcb55a829eacdf0b89c4e 100644 --- a/go.sum +++ b/go.sum @@ -383,5 +383,6 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/lib/config/config.go b/lib/config/config.go index 27a284cbb6cbd59255ba53573476bb90b94f237b..978a80a01f8375dd24066864cb60765587796a66 100644 --- a/lib/config/config.go +++ b/lib/config/config.go @@ -2,8 +2,10 @@ package config import ( "os" + "path/filepath" "github.com/BurntSushi/toml" + "gopkg.in/yaml.v3" ) type PoolMode string @@ -21,9 +23,41 @@ const ( SERVERROLE_NONE ServerRole = "NONE" ) +type UserRole string + +const ( + USERROLE_ADMIN ServerRole = "admin" + USERROLE_WRITER ServerRole = "writer" + USERROLE_READER ServerRole = "reader" +) + +func Load(path string) (*Global, error) { + var g Global + ext := filepath.Ext(path) + file, err := os.ReadFile(path) + if err != nil { + return nil, err + } + switch ext { + case "toml": + err := toml.Unmarshal(file, &g) + if err != nil { + return nil, err + } + case "yml", "yaml", "json": + fallthrough + default: + err := yaml.Unmarshal(file, &g) + if err != nil { + return nil, err + } + } + return &g, nil +} + type Global struct { - General General `toml:"general" yaml:"general" json:"general"` - Pools map[string]Pool `toml:"pools" yaml:"pools" json:"pools"` + General General `toml:"general" yaml:"general" json:"general"` + Pools map[string]*Pool `toml:"pools" yaml:"pools" json:"pools"` } type General struct { @@ -59,56 +93,29 @@ type Pool struct { PrimaryReadsEnabled bool `toml:"primary_reads_enabled" yaml:"primary_reads_enabled" json:"primary_reads_enabled"` ShardingFunction string `toml:"sharding_function" yaml:"sharding_function" json:"sharding_function"` - Shards map[string]Shard `toml:"shards" yaml:"shards" json:"shards"` - Users map[string]User `toml:"users" yaml:"users" json:"users"` + Shards []*Shard `toml:"shards" yaml:"shards" json:"shards"` + Users []*User `toml:"users" yaml:"users" json:"users"` } type User struct { - Name string `toml:"username" yaml:"name" json:"name"` - Password string `toml:"password" yaml:"password" json:"password"` - PoolSize int `toml:"pool_size" yaml:"pool_size" json:"pool_size"` - StatementTimeout int `toml:"statement_timeout" yaml:"statement_timeout" json:"statement_timeout"` -} + Name string `toml:"username" yaml:"username" json:"username"` + Password string `toml:"password" yaml:"password" json:"password"` -type Shard struct { - Database string `toml:"database" yaml:"database" json:"database"` - Servers []Server `toml:"servers" yaml:"servers" json:"servers"` -} - -type Server [3]any - -func (o Server) Host() string { - if v, ok := o[0].(string); ok { - return v - } - return "" + Role UserRole `toml:"role" yaml:"role" json:"role"` + PoolSize int `toml:"pool_size" yaml:"pool_size" json:"pool_size"` + StatementTimeout int `toml:"statement_timeout" yaml:"statement_timeout" json:"statement_timeout"` } -func (o Server) Port() uint16 { - if v, ok := o[1].(int); ok { - return uint16(v) - } - return 5432 +type Shard struct { + Database string `toml:"database" yaml:"database" json:"database"` + Servers []*Server `toml:"servers" yaml:"servers" json:"servers"` } -func (o Server) Role() ServerRole { - if v, ok := o[2].(string); ok { - switch ServerRole(v) { - case SERVERROLE_PRIMARY, SERVERROLE_REPLICA: - return ServerRole(v) - default: - } - } - return ServerRole(SERVERROLE_NONE) -} +type Server struct { + Host string `toml:"host" yaml:"host" json:"host"` + Port uint16 `toml:"port" yaml:"port" json:"port"` + Role ServerRole `toml:"role" yaml:"role" json:"role"` -func Load(path string) (conf *Global, err error) { - conf = new(Global) - var f []byte - f, err = os.ReadFile(path) - if err != nil { - return - } - err = toml.Unmarshal(f, conf) - return + Username string `toml:"username" yaml:"username" json:"username"` + Password string `toml:"password" yaml:"password" json:"password"` } diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 43d99a1dbf213da3bd61265945887985eeed7526..a6ff23d80b774a848f14e2f6489482f150663017 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -8,6 +8,11 @@ import ( "crypto/tls" "errors" "fmt" + "io" + "math/big" + "net" + "reflect" + "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/gatling/messages" @@ -16,10 +21,6 @@ import ( "git.tuxpa.in/a/zlog" "git.tuxpa.in/a/zlog/log" "github.com/ethereum/go-ethereum/common/math" - "io" - "math/big" - "net" - "reflect" ) // / client state, one per client @@ -239,7 +240,6 @@ func (c *Client) Accept(ctx context.Context) error { return err } - c.log.Debug().Msg("Password authentication successful") authOk := new(protocol.Authentication) authOk.Fields.Code = 0 _, err = authOk.Write(c.wr) @@ -268,7 +268,6 @@ func (c *Client) Accept(ctx context.Context) error { if err != nil { return err } - c.log.Debug().Msg("Ready for Query") go c.recvLoop() open := true for open { @@ -307,7 +306,6 @@ func (c *Client) tick(ctx context.Context) (bool, error) { case <-ctx.Done(): return false, ctx.Err() } - switch cast := rsp.(type) { case *protocol.Query: return true, c.handle_query(ctx, cast) @@ -321,30 +319,18 @@ func (c *Client) tick(ctx context.Context) (bool, error) { } func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { - done, err := c.server.Query(c, ctx, q.Fields.Query) + err := c.server.Query(ctx, c, q.Fields.Query) if err != nil { return err } - // wait for query to finish - <-done.Done() - err = done.Err() - if errors.Is(err, context.Canceled) { - return nil - } return err } func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) error { - done, err := c.server.CallFunction(c, ctx, f) + err := c.server.CallFunction(ctx, c, f) if err != nil { return err } - // wait for function call to finish - <-done.Done() - err = done.Err() - if errors.Is(err, context.Canceled) { - return nil - } return err } diff --git a/lib/gat/gatling/conn_pool/conn_pool.go b/lib/gat/gatling/conn_pool/conn_pool.go index 0c69a9df590c04ee6d706ce7d80e046a93f0ca42..d4b6b74306a7c932502592dc7258359438a2dc54 100644 --- a/lib/gat/gatling/conn_pool/conn_pool.go +++ b/lib/gat/gatling/conn_pool/conn_pool.go @@ -3,16 +3,16 @@ package conn_pool import ( "context" "fmt" - "gfx.cafe/gfx/pggat/lib/config" - "gfx.cafe/gfx/pggat/lib/gat" - "gfx.cafe/gfx/pggat/lib/gat/gatling/server" - "gfx.cafe/gfx/pggat/lib/gat/protocol" "log" "math/rand" "reflect" - "strconv" + "runtime" "sync" - "time" + + "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/gatling/conn_pool/server" + "gfx.cafe/gfx/pggat/lib/gat/protocol" ) type request[T any] struct { @@ -52,41 +52,44 @@ type shard struct { } type ConnectionPool struct { - c *config.Pool - user *config.User - pool gat.Pool - shards []shard - queries chan request[string] - functionCalls chan request[*protocol.FunctionCall] - + // the pool connection + c *config.Pool + user *config.User + pool gat.Pool + shards []shard + + workerPool chan *worker + // the lock for config related things mu sync.RWMutex } func NewConnectionPool(pool gat.Pool, conf *config.Pool, user *config.User) *ConnectionPool { p := &ConnectionPool{ - user: user, - pool: pool, - queries: make(chan request[string]), - functionCalls: make(chan request[*protocol.FunctionCall]), + user: user, + pool: pool, + workerPool: make(chan *worker, 1+runtime.NumCPU()*4), } p.EnsureConfig(conf) for i := 0; i < user.PoolSize; i++ { - go p.worker() + p.add_pool() } return p } +func (c *ConnectionPool) add_pool() { + select { + case c.workerPool <- &worker{ + w: c, + }: + default: + } +} + func (c *ConnectionPool) EnsureConfig(conf *config.Pool) { c.mu.Lock() defer c.mu.Unlock() - c.c = conf - for idx, s := range conf.Shards { - i, err := strconv.Atoi(idx) - if err != nil { - log.Printf("expected shard name to be a number, found '%s'", idx) - continue - } + for i, s := range conf.Shards { for i >= len(c.shards) { c.shards = append(c.shards, shard{}) } @@ -94,7 +97,7 @@ func (c *ConnectionPool) EnsureConfig(conf *config.Pool) { if !reflect.DeepEqual(c.shards[i].conf, &sc) { // disconnect all connections, switch to new conf c.shards[i].servers = nil - c.shards[i].conf = &sc + c.shards[i].conf = sc } } } @@ -115,103 +118,47 @@ func (c *ConnectionPool) chooseShard() *shard { func (c *ConnectionPool) chooseServer() *servers { s := c.chooseShard() if s == nil { - log.Println("no available shard for query!") + log.Println("no available shard for query :(") return nil } - + // lock the shard s.mu.Lock() defer s.mu.Unlock() - - // TODO ideally this would choose the server based on load, capabilities, etc + // TODO ideally this would choose the server based on load, capabilities, etc. for now we just trylock for _, srv := range s.servers { if srv.mu.TryLock() { return srv } } - // there are no servers available in the pool, let's make a new connection - // connect to primary server srvs := &servers{} for _, srvConf := range s.conf.Servers { - srv, err := server.Dial(context.Background(), fmt.Sprintf("%s:%d", srvConf.Host(), srvConf.Port()), c.user, s.conf.Database, nil) + srv, err := server.Dial( + context.Background(), + fmt.Sprintf("%s:%d", srvConf.Host, srvConf.Port), + c.user, s.conf.Database, + srvConf.Username, srvConf.Password, + nil) if err != nil { log.Println("failed to connect to server", err) continue } - switch srvConf.Role() { + switch srvConf.Role { case config.SERVERROLE_PRIMARY: srvs.primary = srv case config.SERVERROLE_REPLICA: srvs.replica = srv } } - if srvs.primary == nil { return nil } - srvs.mu.Lock() - s.servers = append(s.servers, srvs) - return srvs } -func (c *ConnectionPool) worker() { - for { - func() { - select { - case q := <-c.queries: - defer q.done() - srv := c.chooseServer() - if srv == nil { - log.Printf("call to query '%s' failed", q.payload) - return - } - - defer srv.mu.Unlock() - - // run the query - which, err := c.pool.GetRouter().InferRole(q.payload) - if err != nil { - log.Println("error parsing query:", err) - return - } - target := srv.choose(which) - if target == nil { - log.Printf("call to query '%s' failed", q.payload) - return - } - err = target.Query(q.client, q.ctx, q.payload) - if err != nil { - log.Println("error executing query:", err) - } - case f := <-c.functionCalls: - defer f.done() - srv := c.chooseServer() - if srv == nil { - log.Printf("function call '%+v' failed", f.payload) - return - } - - defer srv.mu.Unlock() - - // call the function - target := srv.primary - if target == nil { - log.Printf("function call '%+v' failed", f.payload) - return - } - err := srv.primary.CallFunction(f.client, f.payload) - if err != nil { - log.Println("error calling function:", err) - } - } - }() - } -} - func (c *ConnectionPool) GetUser() *config.User { return c.user } @@ -225,33 +172,12 @@ func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus { return srv.primary.GetServerInfo() } -func (c *ConnectionPool) Query(client gat.Client, ctx context.Context, q string) (context.Context, error) { - // note: these deadlines aren't the time to complete the query, that should be controlled by postgres - // instead, this is the amount of time we allow the client to send extra data for a command - // (mainly so the server doesn't hang indefinitely waiting for client data) - cmdCtx, done := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) - - c.queries <- request[string]{ - client: client, - payload: q, - ctx: cmdCtx, - done: done, - } - - return cmdCtx, nil +func (c *ConnectionPool) Query(ctx context.Context, client gat.Client, q string) error { + return (<-c.workerPool).HandleQuery(ctx, client, q) } -func (c *ConnectionPool) CallFunction(client gat.Client, ctx context.Context, f *protocol.FunctionCall) (context.Context, error) { - cmdCtx, done := context.WithDeadline(ctx, time.Now().Add(1*time.Second)) - - c.functionCalls <- request[*protocol.FunctionCall]{ - client: client, - payload: f, - ctx: cmdCtx, - done: done, - } - - return cmdCtx, nil +func (c *ConnectionPool) CallFunction(ctx context.Context, client gat.Client, f *protocol.FunctionCall) error { + return (<-c.workerPool).HandleFunction(ctx, client, f) } var _ gat.ConnectionPool = (*ConnectionPool)(nil) diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/conn_pool/server/server.go similarity index 88% rename from lib/gat/gatling/server/server.go rename to lib/gat/gatling/conn_pool/server/server.go index c8a5a605928bf7dde59b2f58796753b66c732a54..cbe6eb338aafac058b543a90d5b9685b6acafce2 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/conn_pool/server/server.go @@ -5,12 +5,14 @@ import ( "bytes" "errors" "fmt" - "gfx.cafe/gfx/pggat/lib/gat" - "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" "io" "net" + "strings" "time" + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" + "gfx.cafe/gfx/pggat/lib/gat/protocol" "gfx.cafe/gfx/pggat/lib/util/slices" "gfx.cafe/util/go/bufpool" @@ -46,15 +48,24 @@ type Server struct { last_activity time.Time - db string - user config.User + db string + dbuser string + dbpass string + user config.User log zlog.Logger } -func Dial(ctx context.Context, addr string, user *config.User, db string, stats any) (*Server, error) { +func Dial(ctx context.Context, + addr string, + user *config.User, + db string, dbuser string, dbpass string, + stats any, +) (*Server, error) { s := &Server{ - addr: addr, + addr: addr, + dbuser: dbuser, + dbpass: dbpass, } var err error s.conn, err = net.Dial("tcp", addr) @@ -100,7 +111,7 @@ func (s *Server) startup(ctx context.Context) error { start.Fields.Parameters = []protocol.FieldsStartupMessageParameters{ { Name: "user", - Value: s.user.Name, + Value: s.dbuser, }, { Name: "database", @@ -133,17 +144,14 @@ func (s *Server) connect(ctx context.Context) error { case *protocol.Authentication: switch p.Fields.Code { case 5: //MD5_ENCRYPTED_PASSWORD - case 0: // AUTH SUCCESS case 10: // SASL - s.log.Debug().Msg("starting sasl auth") if slices.Contains(p.Fields.SASLMechanism, scram.SHA256.Name()) { s.log.Debug().Str("method", "scram256").Msg("valid protocol") } else { return fmt.Errorf("unsupported scram version: %s", p.Fields.SASLMechanism) } - - scrm, err = scram.Mechanism(scram.SHA256, s.user.Name, s.user.Password) + scrm, err = scram.Mechanism(scram.SHA256, s.dbuser, s.dbpass) if err != nil { return err } @@ -152,7 +160,6 @@ func (s *Server) connect(ctx context.Context) error { if err != nil { return err } - func() { rsp := new(protocol.AuthenticationResponse) buf := bufpool.Get(len(scrm.Name()) + 1 + 4 + len(bts)) @@ -230,7 +237,7 @@ func (s *Server) forwardTo(client gat.Client, predicate func(pkt protocol.Packet } } -func (s *Server) Query(client gat.Client, ctx context.Context, query string) error { +func (s *Server) Query(ctx context.Context, client gat.Client, query string) error { // send to server q := new(protocol.Query) q.Fields.Query = query @@ -238,20 +245,27 @@ func (s *Server) Query(client gat.Client, ctx context.Context, query string) err if err != nil { return err } - + if strings.Contains(query, "pg_sleep") { + go func() { + time.Sleep(1 * time.Second) + log.Println("cancel: ", s.Cancel()) + }() + } // this function seems wild but it has to be the way it is so we read the whole response, even if the // client fails midway // read responses e := s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + log.Println(pkt) switch r := pkt.(type) { case *protocol.ReadyForQuery: + return err == nil, r.Fields.Status == 'I' case *protocol.CopyInResponse: err = client.Send(pkt) if err != nil { return false, false } - err = s.CopyIn(client, ctx) + err = s.CopyIn(ctx, client) if err != nil { return false, false } @@ -266,19 +280,26 @@ func (s *Server) Query(client gat.Client, ctx context.Context, query string) err return err } -func (s *Server) CopyIn(client gat.Client, ctx context.Context) error { +func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { for { + cctx, cancel := context.WithTimeout(ctx, 5*time.Second) var pkt protocol.Packet + // receive a packet, or done if the ctx gets canceled select { case pkt = <-client.Recv(): - case <-ctx.Done(): + case <-cctx.Done(): _, _ = new(protocol.CopyFail).Write(s.wr) - return ctx.Err() + rfq := new(protocol.ReadyForQuery) + rfq.Fields.Status = 'I' + return client.Send(rfq) } + cancel() + _, err := pkt.Write(s.wr) if err != nil { return err } + switch p := pkt.(type) { case *protocol.CopyDone: return nil diff --git a/lib/gat/gatling/server/server_test.go b/lib/gat/gatling/conn_pool/server/server_test.go similarity index 100% rename from lib/gat/gatling/server/server_test.go rename to lib/gat/gatling/conn_pool/server/server_test.go diff --git a/lib/gat/gatling/conn_pool/worker.go b/lib/gat/gatling/conn_pool/worker.go new file mode 100644 index 0000000000000000000000000000000000000000..e3d0ea6096af4d5e61fbc971816687ae5b021719 --- /dev/null +++ b/lib/gat/gatling/conn_pool/worker.go @@ -0,0 +1,104 @@ +package conn_pool + +import ( + "context" + "fmt" + "log" + + "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/protocol" +) + +type _wp ConnectionPool + +// a single use worker with an embedded connection pool. +// it wraps a pointer to the connection pool. +type worker struct { + // the parent connectino pool + w *ConnectionPool +} + +func (w *worker) HandleFunction(ctx context.Context, c gat.Client, fn *protocol.FunctionCall) error { + log.Println("worker selected for fn") + defer func() { + // return self to the connection pool after + log.Println("worker returned for fn") + w.w.workerPool <- w + }() + + errch := make(chan error) + go func() { + err := w.z_actually_do_fn(ctx, c, fn) + if err != nil { + ctx.Done() + } + errch <- err + close(errch) + }() + return <-errch +} + +func (w *worker) HandleQuery(ctx context.Context, c gat.Client, query string) error { + defer func() { + // return self to the connection pool after + w.w.workerPool <- w + }() + errch := make(chan error) + go func() { + err := w.z_actually_do_query(ctx, c, query) + errch <- err + }() + + // wait until query or close + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errch: + return err + } +} + +func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { + c := w.w + srv := c.chooseServer() + if srv == nil { + return fmt.Errorf("fn('%+v') fail: no server", payload) + } + defer srv.mu.Unlock() + // call the function + target := srv.primary + if target == nil { + return fmt.Errorf("fn('%+v') fail: no target ", payload) + } + err := srv.primary.CallFunction(client, payload) + if err != nil { + return fmt.Errorf("fn('%+v') fail: %w ", payload, err) + } + return nil +} +func (w *worker) z_actually_do_query(ctx context.Context, client gat.Client, payload string) error { + c := w.w + // chose a server + srv := c.chooseServer() + if srv == nil { + return fmt.Errorf("call to query '%s' failed", payload) + } + // note that the server comes locked. you MUST unlock it + defer srv.mu.Unlock() + // run the query on the server + which, err := c.pool.GetRouter().InferRole(payload) + if err != nil { + return fmt.Errorf("error parsing '%s': %w", payload, err) + } + // configures the server to run with a specific role + target := srv.choose(which) + if target == nil { + return fmt.Errorf("call to query '%s' failed", payload) + } + // actually do the query + err = target.Query(ctx, client, payload) + if err != nil { + return fmt.Errorf("error executing query: %w", err) + } + return nil +} diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index cb28256e7212850b11ea598465fa032b796dcf60..9e8d4e187a03dd0dc01714f5cf7cf14b79a7e33a 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -3,12 +3,13 @@ package gatling import ( "context" "fmt" + "net" + "sync" + "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/gatling/client" "gfx.cafe/gfx/pggat/lib/gat/gatling/pool" "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" - "net" - "sync" "git.tuxpa.in/a/zlog/log" @@ -89,9 +90,9 @@ func (g *Gatling) ensureAdmin(c *config.Global) error { func (g *Gatling) ensurePools(c *config.Global) error { for name, p := range c.Pools { if existing, ok := g.pools[name]; ok { - existing.EnsureConfig(&p) + existing.EnsureConfig(p) } else { - g.pools[name] = pool.NewPool(&p) + g.pools[name] = pool.NewPool(p) } } return nil diff --git a/lib/gat/gatling/pool/pool.go b/lib/gat/gatling/pool/pool.go index 49a4ff7881935274d4b0fe305092988a69a83981..79aeca62e6548aca830d832995ce37db76e8e1ec 100644 --- a/lib/gat/gatling/pool/pool.go +++ b/lib/gat/gatling/pool/pool.go @@ -2,11 +2,12 @@ package pool import ( "fmt" + "sync" + "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/gatling/conn_pool" "gfx.cafe/gfx/pggat/lib/gat/gatling/query_router" - "sync" ) type Pool struct { @@ -33,7 +34,7 @@ func (p *Pool) EnsureConfig(conf *config.Pool) { p.c = conf p.users = make(map[string]config.User) for _, user := range conf.Users { - p.users[user.Name] = user + p.users[user.Name] = *user } // ensure conn pools for name, user := range p.users { diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index aa4c8dad1d90b41f949f098a70cf388d7028388d..c342a7702b41ae76d555eb20f501a5c1cd7fb39f 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -15,8 +15,8 @@ type Client interface { type ConnectionPool interface { GetUser() *config.User GetServerInfo() []*protocol.ParameterStatus - Query(client Client, ctx context.Context, query string) (context.Context, error) - CallFunction(client Client, ctx context.Context, payload *protocol.FunctionCall) (context.Context, error) + Query(ctx context.Context, client Client, query string) error + CallFunction(ctx context.Context, client Client, payload *protocol.FunctionCall) error } type Gat interface { diff --git a/lib/util/maps/maps.go b/lib/util/maps/maps.go index 95e2f7953cd9a7c640f97fffd3a467ef14f0ab63..5a5a162bd4bc46deb6f5f364143a42ccdb828c28 100644 --- a/lib/util/maps/maps.go +++ b/lib/util/maps/maps.go @@ -6,7 +6,5 @@ func FirstWhere[K comparable, V any](haystack map[K]V, predicate func(K, V) bool return k, v, true } } - var k K - var v V - return k, v, false + return *new(K), *new(V), false }