diff --git a/lib/gat/admin/admin.go b/lib/gat/admin/admin.go index c0854cc94e2fac473834c3bd4ec7879ffd024af9..7363c911551d95fee2d2d8b5fcbf7451d5c80e4c 100644 --- a/lib/gat/admin/admin.go +++ b/lib/gat/admin/admin.go @@ -50,7 +50,7 @@ func getServerInfo(g gat.Gat) []*protocol.ParameterStatus { { Fields: protocol.FieldsParameterStatus{ Parameter: "server_version", - Value: g.Version(), + Value: g.GetVersion(), }, }, { @@ -63,7 +63,7 @@ func getServerInfo(g gat.Gat) []*protocol.ParameterStatus { } func getAdminUser(g gat.Gat) *config.User { - conf := g.Config() + conf := g.GetConfig() return &config.User{ Name: conf.General.AdminUsername, Password: conf.General.AdminPassword, @@ -183,7 +183,7 @@ func (p *Pool) showStats(client gat.Client, totals, averages bool) error { if err != nil { return err } - for name, pl := range p.gat.Pools() { + for name, pl := range p.gat.GetPools() { stats := pl.GetStats() if stats == nil { continue @@ -346,7 +346,7 @@ func (p *Pool) showTotals(client gat.Client) error { var totalXactCount, totalQueryCount, totalWaitCount, totalReceived, totalSent, totalXactTime, totalQueryTime, totalWaitTime int var alive time.Duration - for _, pl := range p.gat.Pools() { + for _, pl := range p.gat.GetPools() { stats := pl.GetStats() if stats == nil { continue @@ -433,24 +433,24 @@ func NewPool(g gat.Gat) *Pool { return out } -func (p *Pool) GetUser(name string) (*config.User, error) { +func (p *Pool) GetUser(name string) *config.User { u := getAdminUser(p.gat) if name != u.Name { - return nil, fmt.Errorf("%w: %s", gat.UserNotFound, name) + return nil } - return u, nil + return u } func (p *Pool) GetRouter() gat.QueryRouter { return nil } -func (p *Pool) WithUser(name string) (gat.ConnectionPool, error) { - conf := p.gat.Config() +func (p *Pool) WithUser(name string) gat.ConnectionPool { + conf := p.gat.GetConfig() if name != conf.General.AdminUsername { - return nil, fmt.Errorf("%w: %s", gat.UserNotFound, name) + return nil } - return p.connPool, nil + return p.connPool } func (p *Pool) ConnectionPools() []gat.ConnectionPool { diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index aaa23518e985e3d7706abe777dcc09ba6132accf..8b06cc5ad4da10ee733f8a0624453166effd5c11 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -20,7 +20,6 @@ import ( "math/big" "net" "reflect" - "strconv" "strings" "sync" "sync/atomic" @@ -69,9 +68,12 @@ type Client struct { recv chan protocol.Packet + state gat.ClientState + pid int32 secretKey int32 + requestTime time.Time connectTime time.Time server gat.ConnectionPool @@ -91,46 +93,34 @@ type Client struct { mu sync.Mutex } -func (c *Client) State() string { - return "TODO" // TODO -} - -func (c *Client) Addr() string { - addr, _, _ := net.SplitHostPort(c.conn.RemoteAddr().String()) - return addr -} - -func (c *Client) Port() int { - // ignore the errors cuz 0 is fine, just for stats - _, port, _ := net.SplitHostPort(c.conn.RemoteAddr().String()) - p, _ := strconv.Atoi(port) - return p +func (c *Client) GetState() gat.ClientState { + c.mu.Lock() + defer c.mu.Unlock() + return c.state } -func (c *Client) LocalAddr() string { - addr, _, _ := net.SplitHostPort(c.conn.LocalAddr().String()) - return addr +func (c *Client) GetAddress() net.Addr { + return c.conn.RemoteAddr() } -func (c *Client) LocalPort() int { - _, port, _ := net.SplitHostPort(c.conn.LocalAddr().String()) - p, _ := strconv.Atoi(port) - return p +func (c *Client) GetLocalAddress() net.Addr { + return c.conn.LocalAddr() } -func (c *Client) ConnectTime() time.Time { +func (c *Client) GetConnectTime() time.Time { return c.connectTime } -func (c *Client) RequestTime() time.Time { - return c.currentConn.RequestTime() +func (c *Client) startRequest() { + c.state = gat.ClientWaiting + c.requestTime = time.Now() } -func (c *Client) Wait() time.Duration { - return c.currentConn.Wait() +func (c *Client) GetRequestTime() time.Time { + return c.requestTime } -func (c *Client) RemotePid() int { +func (c *Client) GetRemotePid() int { return int(c.pid) } @@ -152,6 +142,7 @@ func NewClient( r: NewCountReader(bufio.NewReader(conn)), wr: NewCountWriter(bufio.NewWriter(conn)), recv: make(chan protocol.Packet), + state: gat.ClientActive, pid: int32(pid.Int64()), secretKey: int32(skey.Int64()), gatling: gatling, @@ -172,14 +163,16 @@ func (c *Client) Id() gat.ClientID { } } -func (c *Client) GetCurrentConn() (gat.Connection, error) { - if c.currentConn == nil { - return nil, errors.New("not connected to a server") - } - return c.currentConn, nil +func (c *Client) GetCurrentConn() gat.Connection { + c.mu.Lock() + defer c.mu.Unlock() + return c.currentConn } func (c *Client) SetCurrentConn(conn gat.Connection) { + c.mu.Lock() + defer c.mu.Unlock() + c.state = gat.ClientActive c.currentConn = conn } @@ -307,17 +300,16 @@ func (c *Client) Accept(ctx context.Context) error { } } - var pool gat.Pool - pool, err = c.gatling.GetPool(c.poolName) - if err != nil { - return err + pool := c.gatling.GetPool(c.poolName) + if pool == nil { + return fmt.Errorf("pool '%s' not found", c.poolName) } // get user var user *config.User - user, err = pool.GetUser(c.username) - if err != nil { - return err + user = pool.GetUser(c.username) + if user == nil { + return fmt.Errorf("user '%s' not found", c.username) } // Authenticate admin user. @@ -341,9 +333,9 @@ func (c *Client) Accept(ctx context.Context) error { } } - c.server, err = pool.WithUser(c.username) - if err != nil { - return err + c.server = pool.WithUser(c.username) + if c.server == nil { + return fmt.Errorf("no pool for '%s'", c.username) } authOk := new(protocol.Authentication) @@ -424,17 +416,16 @@ func (c *Client) recvLoop() { } func (c *Client) handle_cancel(ctx context.Context, p *protocol.StartupMessage) error { - cl, err := c.gatling.GetClient(gat.ClientID{ + cl := c.gatling.GetClient(gat.ClientID{ PID: p.Fields.ProcessKey, SecretKey: p.Fields.SecretKey, }) - if err != nil { - return err + if cl == nil { + return errors.New("user not found") } - var conn gat.Connection - conn, err = cl.GetCurrentConn() - if err != nil { - return err + conn := cl.GetCurrentConn() + if conn == nil { + return errors.New("not connected to a server") } return conn.Cancel() } @@ -486,12 +477,14 @@ func (c *Client) bind(ctx context.Context, b *protocol.Bind) error { func (c *Client) handle_describe(ctx context.Context, d *protocol.Describe) error { //log.Println("describe") c.status = 'T' + c.startRequest() return c.server.Describe(ctx, c, d) } func (c *Client) handle_execute(ctx context.Context, e *protocol.Execute) error { //log.Println("execute") c.status = 'T' + c.startRequest() return c.server.Execute(ctx, c, e) } @@ -525,6 +518,7 @@ func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { // begin transaction if prev != cmd.Index { query := q.Fields.Query[prev:cmd.Index] + c.startRequest() err = c.handle_simple_query(ctx, query) prev = cmd.Index if err != nil { @@ -541,6 +535,7 @@ func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { query = q.Fields.Query[prev:parsed[idx+1].Index] } if query != "" { + c.startRequest() err = c.handle_transaction(ctx, query) prev = cmd.Index if err != nil { @@ -552,6 +547,7 @@ func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { } } query := q.Fields.Query[prev:] + c.startRequest() if transaction { err = c.handle_transaction(ctx, query) } else { @@ -562,20 +558,19 @@ func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { func (c *Client) handle_simple_query(ctx context.Context, q string) error { //log.Println("query:", q) + c.startRequest() return c.server.SimpleQuery(ctx, c, q) } func (c *Client) handle_transaction(ctx context.Context, q string) error { //log.Println("transaction:", q) + c.startRequest() return c.server.Transaction(ctx, c, q) } func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) error { - err := c.server.CallFunction(ctx, c, f) - if err != nil { - return err - } - return err + c.startRequest() + return c.server.CallFunction(ctx, c, f) } func (c *Client) GetPreparedStatement(name string) *protocol.Parse { diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index 4d6f6055f38028b2c53a753e79120a9845b672be..115276f0a912474df8aec4c2a56f1f9dd3852bd2 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -2,7 +2,6 @@ package gatling import ( "context" - "errors" "fmt" "gfx.cafe/gfx/pggat/lib/gat/admin" "io" @@ -60,41 +59,41 @@ func (g *Gatling) watchConfigs() { } } -func (g *Gatling) Version() string { +func (g *Gatling) GetVersion() string { return "PgGat Gatling 0.0.1" } -func (g *Gatling) Config() *config.Global { +func (g *Gatling) GetConfig() *config.Global { return g.c } -func (g *Gatling) GetPool(name string) (gat.Pool, error) { +func (g *Gatling) GetPool(name string) gat.Pool { g.mu.RLock() defer g.mu.RUnlock() srv, ok := g.pools[name] if !ok { - return nil, fmt.Errorf("pool '%s' not found", name) + return nil } - return srv, nil + return srv } -func (g *Gatling) Pools() map[string]gat.Pool { +func (g *Gatling) GetPools() map[string]gat.Pool { g.mu.RLock() defer g.mu.RUnlock() return g.pools } -func (g *Gatling) GetClient(id gat.ClientID) (gat.Client, error) { +func (g *Gatling) GetClient(id gat.ClientID) gat.Client { g.mu.RLock() defer g.mu.RUnlock() c, ok := g.clients[id] if !ok { - return nil, errors.New("client not found") + return nil } - return c, nil + return c } -func (g *Gatling) Clients() []gat.Client { +func (g *Gatling) GetClients() []gat.Client { g.mu.RLock() defer g.mu.RUnlock() out := make([]gat.Client, len(g.clients)) diff --git a/lib/gat/gatling/pool/conn_pool/shard/server/server.go b/lib/gat/gatling/pool/conn_pool/shard/server/server.go index e2aa43d48bb2ed950b3cfdb13cae07828fb8cad9..047cb4d3368ec44af87edb22a8bd34e0f23efe0e 100644 --- a/lib/gat/gatling/pool/conn_pool/shard/server/server.go +++ b/lib/gat/gatling/pool/conn_pool/shard/server/server.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "reflect" - "strconv" "sync" "time" @@ -32,7 +31,7 @@ type Server struct { wr *bufio.Writer client gat.Client - state string + state gat.ConnectionState serverInfo []*protocol.ParameterStatus @@ -66,7 +65,7 @@ func Dial(ctx context.Context, addr: addr, port: port, - state: "new", + state: gat.ConnectionNew, boundPreparedStatments: make(map[string]*protocol.Parse), boundPortals: make(map[string]*protocol.Bind), @@ -110,58 +109,49 @@ func (s *Server) GetDatabase() string { return s.db } -func (s *Server) State() string { +func (s *Server) GetState() gat.ConnectionState { s.mu.Lock() defer s.mu.Unlock() return s.state } -func (s *Server) Address() string { +func (s *Server) GetHost() string { return s.addr } -func (s *Server) Port() int { +func (s *Server) GetPort() int { return int(s.port) } -func (s *Server) LocalAddr() string { +func (s *Server) GetAddress() net.Addr { s.mu.Lock() defer s.mu.Unlock() - addr, _, _ := net.SplitHostPort(s.conn.LocalAddr().String()) - return addr + return s.conn.RemoteAddr() } -func (s *Server) LocalPort() int { +func (s *Server) GetLocalAddress() net.Addr { s.mu.Lock() defer s.mu.Unlock() - _, port, _ := net.SplitHostPort(s.conn.LocalAddr().String()) - p, _ := strconv.Atoi(port) - return p + return s.conn.LocalAddr() } -func (s *Server) ConnectTime() time.Time { +func (s *Server) GetConnectTime() time.Time { s.mu.Lock() defer s.mu.Unlock() return s.connectedAt } -func (s *Server) RequestTime() time.Time { +func (s *Server) GetRequestTime() time.Time { s.mu.Lock() defer s.mu.Unlock() return s.lastActivity } -func (s *Server) Wait() time.Duration { - s.mu.Lock() - defer s.mu.Unlock() - return time.Now().Sub(s.lastActivity) -} - -func (s *Server) CloseNeeded() bool { +func (s *Server) IsCloseNeeded() bool { return false } -func (s *Server) Client() gat.Client { +func (s *Server) GetClient() gat.Client { s.mu.Lock() defer s.mu.Unlock() return s.client @@ -171,16 +161,21 @@ func (s *Server) SetClient(client gat.Client) { s.mu.Lock() defer s.mu.Unlock() s.lastActivity = time.Now() + if client != nil { + s.state = gat.ConnectionActive + } else { + s.state = gat.ConnectionIdle + } s.client = client } -func (s *Server) RemotePid() int { +func (s *Server) GetRemotePid() int { s.mu.Lock() defer s.mu.Unlock() return int(s.processId) } -func (s *Server) TLS() string { +func (s *Server) GetTLS() string { return "" // TODO } diff --git a/lib/gat/gatling/pool/conn_pool/shard/shard.go b/lib/gat/gatling/pool/conn_pool/shard/shard.go index 7a301e76687e6d18371e47545452766bb4a73218..e54cad49cb58fcf40be009079b5671474331a9c0 100644 --- a/lib/gat/gatling/pool/conn_pool/shard/shard.go +++ b/lib/gat/gatling/pool/conn_pool/shard/shard.go @@ -48,11 +48,11 @@ func (s *Shard) Choose(role config.ServerRole) gat.Connection { } } -func (s *Shard) Primary() gat.Connection { +func (s *Shard) GetPrimary() gat.Connection { return s.primary } -func (s *Shard) Replicas() []gat.Connection { +func (s *Shard) GetReplicas() []gat.Connection { return s.replicas } diff --git a/lib/gat/gatling/pool/conn_pool/worker.go b/lib/gat/gatling/pool/conn_pool/worker.go index 930f62b42e28bcb56b8de6f1789ca94af8e5799f..a1125b98e0ba4aa44f7e50d385871192bad3a88b 100644 --- a/lib/gat/gatling/pool/conn_pool/worker.go +++ b/lib/gat/gatling/pool/conn_pool/worker.go @@ -95,7 +95,7 @@ func (w *worker) GetServerInfo() []*protocol.ParameterStatus { return nil } - primary := shard.Primary() + primary := shard.GetPrimary() if primary == nil { return nil } @@ -284,7 +284,7 @@ func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payloa return fmt.Errorf("fn('%+v') fail: no server", payload) } // call the function - target := srv.Primary() + target := srv.GetPrimary() if target == nil { return fmt.Errorf("fn('%+v') fail: no target ", payload) } diff --git a/lib/gat/gatling/pool/pool.go b/lib/gat/gatling/pool/pool.go index 4b1d91e180dd662339fbed03d311739c93e56dcc..cc8d33f79b9bcaf2c63edf3a2a8202d882df783a 100644 --- a/lib/gat/gatling/pool/pool.go +++ b/lib/gat/gatling/pool/pool.go @@ -1,7 +1,6 @@ package pool import ( - "fmt" "gfx.cafe/gfx/pggat/lib/gat/gatling/pool/conn_pool" "gfx.cafe/gfx/pggat/lib/gat/gatling/pool/query_router" "sync" @@ -50,28 +49,28 @@ func (p *Pool) EnsureConfig(conf *config.Pool) { } } -func (p *Pool) GetUser(name string) (*config.User, error) { +func (p *Pool) GetUser(name string) *config.User { p.mu.RLock() defer p.mu.RUnlock() user, ok := p.users[name] if !ok { - return nil, fmt.Errorf("%w: %s", gat.UserNotFound, name) + return nil } - return &user, nil + return &user } func (p *Pool) GetRouter() gat.QueryRouter { return &p.router } -func (p *Pool) WithUser(name string) (gat.ConnectionPool, error) { +func (p *Pool) WithUser(name string) gat.ConnectionPool { p.mu.RLock() defer p.mu.RUnlock() pool, ok := p.connPools[name] if !ok { - return nil, fmt.Errorf("no pool for '%s'", name) + return nil } - return pool, nil + return pool } func (p *Pool) ConnectionPools() []gat.ConnectionPool { diff --git a/lib/gat/gatling/pool/query_router/query_router.go b/lib/gat/gatling/pool/query_router/query_router.go index d5e44bb5078e38cc3a654078eab4c573319c8ae5..1bc0bf0fea41cf9fd3baa4382640d0121e153aa6 100644 --- a/lib/gat/gatling/pool/query_router/query_router.go +++ b/lib/gat/gatling/pool/query_router/query_router.go @@ -108,7 +108,7 @@ func (r *QueryRouter) try_execute_command(pkt *protocol.Query) (Command, string) // Command::ShowShard => self.shard().to_string(), // Command::ShowServerRole => switch self.active_role { - // Some(Role::Primary) => string("primary"), + // Some(Role::GetPrimary) => string("primary"), // Some(Role::Replica) => string("replica"), // None => { // if self.query_parser_enabled { @@ -147,7 +147,7 @@ func (r *QueryRouter) try_execute_command(pkt *protocol.Query) (Command, string) // self.active_role := switch value.to_ascii_lowercase().as_ref() { // "primary" => { // self.query_parser_enabled := false - // Some(Role::Primary) + // Some(Role::GetPrimary) // } // "replica" => { diff --git a/lib/gat/gatling/pool/query_router/query_router_test.go b/lib/gat/gatling/pool/query_router/query_router_test.go index 8ae996efc5f739c08bd2c76ebb1550c4f127e476..31b3fd4b244af6c963d689ac326743f137ed33f8 100644 --- a/lib/gat/gatling/pool/query_router/query_router_test.go +++ b/lib/gat/gatling/pool/query_router/query_router_test.go @@ -61,7 +61,7 @@ func TestQueryRouterInterRoleReplica(t *testing.T) { // for query in queries { // // It's a recognized query // assert!(qr.infer_role(query)); -// assert_eq!(qr.role(), Some(Role::Primary)); +// assert_eq!(qr.role(), Some(Role::GetPrimary)); // } // } // @@ -198,11 +198,11 @@ func TestQueryRouterInterRoleReplica(t *testing.T) { // // SetServerRole // let roles = ["primary", "replica", "any", "auto", "primary"]; // let verify_roles = [ -// Some(Role::Primary), +// Some(Role::GetPrimary), // Some(Role::Replica), // None, // None, -// Some(Role::Primary), +// Some(Role::GetPrimary), // ]; // let query_parser_enabled = [false, false, false, true, false]; // @@ -257,7 +257,7 @@ func TestQueryRouterInterRoleReplica(t *testing.T) { // // let query = simple_query("INSERT INTO test_table VALUES (1)"); // assert_eq!(qr.infer_role(query), true); -// assert_eq!(qr.role(), Some(Role::Primary)); +// assert_eq!(qr.role(), Some(Role::GetPrimary)); // // let query = simple_query("SELECT * FROM test_table"); // assert_eq!(qr.infer_role(query), true); @@ -298,7 +298,7 @@ func TestQueryRouterInterRoleReplica(t *testing.T) { // // let q1 = simple_query("SET SERVER ROLE TO 'primary'"); // assert!(qr.try_execute_command(q1) != None); -// assert_eq!(qr.active_role.unwrap(), Role::Primary); +// assert_eq!(qr.active_role.unwrap(), Role::GetPrimary); // // let q2 = simple_query("SET SERVER ROLE TO 'default'"); // assert!(qr.try_execute_command(q2) != None); diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index c682e663787ed545b95d7e97ef7a612dd7e8ec7a..b556f8f3e3d7cf29207c851419d9467eb78fc4a1 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -2,9 +2,9 @@ package gat import ( "context" - "errors" "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat/protocol" + "net" "time" ) @@ -13,23 +13,27 @@ type ClientID struct { SecretKey int32 } +type ClientState string + +const ( + ClientActive ClientState = "active" + ClientWaiting = "waiting" +) + type Client interface { GetPreparedStatement(name string) *protocol.Parse GetPortal(name string) *protocol.Bind - GetCurrentConn() (Connection, error) + GetCurrentConn() Connection SetCurrentConn(conn Connection) GetConnectionPool() ConnectionPool - State() string - Addr() string - Port() int - LocalAddr() string - LocalPort() int - ConnectTime() time.Time - RequestTime() time.Time - Wait() time.Duration - RemotePid() int + GetState() ClientState + GetAddress() net.Addr + GetLocalAddress() net.Addr + GetConnectTime() time.Time + GetRequestTime() time.Time + GetRemotePid() int Send(pkt protocol.Packet) error Flush() error @@ -37,21 +41,19 @@ type Client interface { } type Gat interface { - Version() string - Config() *config.Global - GetPool(name string) (Pool, error) - Pools() map[string]Pool - GetClient(id ClientID) (Client, error) - Clients() []Client + GetVersion() string + GetConfig() *config.Global + GetPool(name string) Pool + GetPools() map[string]Pool + GetClient(id ClientID) Client + GetClients() []Client } -var UserNotFound = errors.New("user not found") - type Pool interface { - GetUser(name string) (*config.User, error) + GetUser(name string) *config.User GetRouter() QueryRouter - WithUser(name string) (ConnectionPool, error) + WithUser(name string) ConnectionPool ConnectionPools() []ConnectionPool GetStats() *PoolStats @@ -84,28 +86,37 @@ type ConnectionPool interface { } type Shard interface { - Primary() Connection - Replicas() []Connection + GetPrimary() Connection + GetReplicas() []Connection Choose(role config.ServerRole) Connection } +type ConnectionState string + +const ( + ConnectionActive ConnectionState = "active" + ConnectionIdle = "idle" + ConnectionUsed = "used" + ConnectionTested = "tested" + ConnectionNew = "new" +) + type Connection interface { GetServerInfo() []*protocol.ParameterStatus GetDatabase() string - State() string - Address() string - Port() int - LocalAddr() string - LocalPort() int - ConnectTime() time.Time - RequestTime() time.Time - Wait() time.Duration - CloseNeeded() bool - Client() Client + GetState() ConnectionState + GetHost() string + GetPort() int + GetAddress() net.Addr + GetLocalAddress() net.Addr + GetConnectTime() time.Time + GetRequestTime() time.Time + IsCloseNeeded() bool + GetClient() Client SetClient(client Client) - RemotePid() int - TLS() string + GetRemotePid() int + GetTLS() string // actions Describe(client Client, payload *protocol.Describe) error