diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index c86501514111709645d98bf64cb631450b4f4e8d..de04bfb4ee4ef9465b050162ee5dae08ba50ede9 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -211,6 +211,10 @@ func (c *Client) SetCurrentConn(conn gat.Connection) { } func (c *Client) Accept(ctx context.Context) error { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + // read a packet startup := new(protocol.StartupMessage) err := startup.Read(c.r) @@ -406,7 +410,7 @@ func (c *Client) Accept(ctx context.Context) error { if err != nil { return err } - go c.recvLoop() + go c.recvLoop(cancel) open := true for open { err = c.Flush() @@ -435,7 +439,8 @@ func (c *Client) Accept(ctx context.Context) error { return nil } -func (c *Client) recvLoop() { +func (c *Client) recvLoop(cancel context.CancelFunc) { + defer cancel() for { recv, err := protocol.ReadFrontend(c.r) if err != nil { diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index 6e9f8cb015e6b5f979b4d83144729fcc14ae7048..bbe640d2ed4eac2e7437311f40a5ccc9219cfb99 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -8,6 +8,8 @@ import ( "net" "sync" + "gfx.cafe/util/go/generic" + "gfx.cafe/gfx/pggat/lib/gat/admin" "gfx.cafe/gfx/pggat/lib/gat/database" "gfx.cafe/gfx/pggat/lib/gat/gatling/server" @@ -31,14 +33,13 @@ type Gatling struct { chConfig chan *config.Global pools map[string]gat.Database - clients map[gat.ClientID]gat.Client + clients generic.Map[gat.ClientID, gat.Client] } func NewGatling(conf *config.Global) *Gatling { g := &Gatling{ chConfig: make(chan *config.Global, 1), pools: make(map[string]gat.Database), - clients: make(map[gat.ClientID]gat.Client), } // add admin pool adminPool := admin.New(g) @@ -88,24 +89,18 @@ func (g *Gatling) GetDatabases() map[string]gat.Database { } func (g *Gatling) GetClient(id gat.ClientID) gat.Client { - g.mu.RLock() - defer g.mu.RUnlock() - c, ok := g.clients[id] + c, ok := g.clients.Load(id) if !ok { return nil } return c } -func (g *Gatling) GetClients() []gat.Client { - g.mu.RLock() - defer g.mu.RUnlock() - out := make([]gat.Client, len(g.clients)) - idx := 0 - for _, p := range g.clients { - out[idx] = p - idx += 1 - } +func (g *Gatling) GetClients() (out []gat.Client) { + g.clients.Range(func(id gat.ClientID, client gat.Client) bool { + out = append(out, client) + return true + }) return out } @@ -164,12 +159,7 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { } metrics.RecordAcceptConnectionStatus(err) close(errch) - err = g.handleConnection(ctx, c) - if err != nil { - if err != io.EOF { - log.Println("disconnected:", err) - } - } + g.handleConnection(ctx, c) }() err = <-errch @@ -183,20 +173,14 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { } // TODO: TLS -func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { +func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) { cl := client.NewClient(g, g.c, c, false) - func() { - g.mu.Lock() - defer g.mu.Unlock() - g.clients[cl.GetId()] = cl - metrics.RecordActiveConnections(len(g.clients)) - }() + g.clients.Store(cl.GetId(), cl) + metrics.RecordActiveConnections(1) defer func() { - g.mu.Lock() - defer g.mu.Unlock() - delete(g.clients, cl.GetId()) - metrics.RecordActiveConnections(len(g.clients)) + g.clients.Delete(cl.GetId()) + metrics.RecordActiveConnections(-1) }() err := cl.Accept(ctx) @@ -208,7 +192,6 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { } } _ = c.Close() - return nil } var _ gat.Gat = (*Gatling)(nil) diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 7ef62ee337c6bf267256cd4683c927c3b1789dcc..bd4720b4ea3e85c7ae991d2726c2e40daa94c4b3 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "errors" "fmt" "net" "reflect" @@ -52,8 +53,10 @@ type Server struct { dbpass string user config.User - healthy bool - awaitingSync bool + healthy bool + awaitingSync bool + readyForQuery bool + copying bool log zlog.Logger @@ -155,39 +158,8 @@ func (s *Server) failHealthCheck(err error) { } func (s *Server) healthCheck() { - check := new(protocol.Query) - check.Fields.Query = "select 1" - err := s.writePacket(check) - if err != nil { - s.failHealthCheck(err) - return - } - err = s.flush() - if err != nil { - s.failHealthCheck(err) - return - } - - // read until we get a ready for query - for { - var recv protocol.Packet - recv, err = s.readPacket() - if err != nil { - s.failHealthCheck(err) - return - } - - switch r := recv.(type) { - case *protocol.ReadyForQuery: - if r.Fields.Status != 'I' { - s.failHealthCheck(fmt.Errorf("expected server to be in command mode but it isn't")) - } - return - case *protocol.DataRow, *protocol.RowDescription, *protocol.CommandComplete: - default: - s.failHealthCheck(fmt.Errorf("expected a Simple Query packet but server sent %#v", recv)) - return - } + if !s.readyForQuery { + s.failHealthCheck(errors.New("expected server to be ready for query")) } } @@ -346,31 +318,12 @@ func (s *Server) connect(ctx context.Context) error { s.lastActivity = time.Now() s.connectedAt = time.Now().UTC() s.state = "idle" + s.readyForQuery = true return nil } } } -func (s *Server) forwardTo(client gat.Client, predicate func(pkt protocol.Packet) (forward bool, finish bool, err error)) error { - var e error - for { - var rsp protocol.Packet - rsp, err := s.readPacket() - if err != nil { - return err - } - //log.Printf("backend packet(%s) %+v", reflect.TypeOf(rsp), rsp) - var forward, finish bool - forward, finish, e = predicate(rsp) - if forward && e == nil { - e = client.Send(rsp) - } - if finish { - return e - } - } -} - func (s *Server) writePacket(pkt protocol.Packet) error { //log.Printf("out %#v", pkt) _, err := pkt.Write(s.wr) @@ -388,10 +341,78 @@ func (s *Server) readPacket() (protocol.Packet, error) { } func (s *Server) stabilize() { - // TODO actually stabilize connection + if s.readyForQuery { + return + } + //log.Println("connection is unstable, attempting to restabilize it") + if s.copying { + //log.Println("failing copy") + s.copying = false + err := s.writePacket(new(protocol.CopyFail)) + if err != nil { + return + } + } if s.awaitingSync { - _ = s.writePacket(new(protocol.Sync)) - _ = s.flush() + //log.Println("syncing") + s.awaitingSync = false + err := s.writePacket(new(protocol.Sync)) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + } + query := new(protocol.Query) + query.Fields.Query = "end" + err := s.writePacket(query) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + + for { + var pkt protocol.Packet + pkt, err = s.readPacket() + if err != nil { + return + } + + //log.Printf("received %+v", pkt) + + switch pk := pkt.(type) { + case *protocol.ReadyForQuery: + if pk.Fields.Status == 'I' { + s.readyForQuery = true + return + } else { + query := new(protocol.Query) + query.Fields.Query = "end" + err = s.writePacket(query) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + } + case *protocol.CopyInResponse, *protocol.CopyBothResponse: + fail := new(protocol.CopyFail) + err = s.writePacket(fail) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + } } } @@ -479,10 +500,6 @@ func (s *Server) destructPortal(name string) { s.destructPreparedStatement(portal.Fields.PreparedStatement) } -func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { - return s.sendAndLink(ctx, client, d) -} - func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { switch pkt := packet.(type) { case *protocol.FunctionCall, *protocol.Query: @@ -546,6 +563,7 @@ func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { } func (s *Server) sendAndLink(ctx context.Context, client gat.Client, initial protocol.Packet) error { + s.readyForQuery = false err := s.handleRecv(client, initial) if err != nil { return err @@ -571,6 +589,7 @@ func (s *Server) link(ctx context.Context, client gat.Client) error { case *protocol.ReadyForQuery: if p.Fields.Status == 'I' { // this client is done + s.readyForQuery = true return nil } @@ -591,7 +610,7 @@ func (s *Server) link(ctx context.Context, client gat.Client) error { if err != nil { return err } - case *protocol.CopyInResponse: + case *protocol.CopyInResponse, *protocol.CopyBothResponse: err = client.Send(p) if err != nil { return err @@ -632,6 +651,10 @@ func (s *Server) awaitSync(ctx context.Context, client gat.Client) error { return nil } +func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { + return s.sendAndLink(ctx, client, d) +} + func (s *Server) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { return s.sendAndLink(ctx, client, e) } @@ -650,6 +673,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin } func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { + s.copying = true err := client.Flush() if err != nil { return err @@ -660,8 +684,6 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { select { case pkt = <-client.Recv(): case <-ctx.Done(): - _ = s.writePacket(new(protocol.CopyFail)) - _ = s.flush() return ctx.Err() } err = s.writePacket(pkt) @@ -671,6 +693,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { switch pkt.(type) { case *protocol.CopyDone, *protocol.CopyFail: + s.copying = false // don't error on copyfail because the client is the one that errored, it already knows return s.flush() } diff --git a/lib/gat/pool/transaction/worker.go b/lib/gat/pool/transaction/worker.go index 25176cd0aa0fd6445c928a14647a388a42c3333e..bb0dd81482a2253cdc7ce98e56e7c728fda45205 100644 --- a/lib/gat/pool/transaction/worker.go +++ b/lib/gat/pool/transaction/worker.go @@ -12,7 +12,6 @@ import ( "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/pool/transaction/shard" "gfx.cafe/gfx/pggat/lib/gat/protocol" - "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" "gfx.cafe/gfx/pggat/lib/metrics" ) @@ -275,12 +274,11 @@ func (w *worker) z_actually_do_describe(ctx context.Context, client gat.Client, return fmt.Errorf("describe('%+v') fail: no server", payload) } // describe the portal - // we can use a replica because we are just describing what this query will return, query content doesn't matter - // because nothing is actually executed yet - if !w.w.user.Role.CanUse(config.SERVERROLE_REPLICA) { + // have to use primary because it could be executed + if !w.w.user.Role.CanUse(config.SERVERROLE_PRIMARY) { return errors.New("permission denied") } - target := srv.Choose(config.SERVERROLE_REPLICA) + target := srv.GetPrimary() if target == nil { return fmt.Errorf("describe('%+v') fail: no server", payload) } @@ -294,33 +292,10 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p return fmt.Errorf("describe('%+v') fail: no server", payload) } - // get the query text - portal := client.GetPortal(payload.Fields.Name) - if portal == nil { - return &pg_error.Error{ - Severity: pg_error.Err, - Code: pg_error.ProtocolViolation, - Message: fmt.Sprintf("portal '%s' not found", payload.Fields.Name), - } - } - - ps := client.GetPreparedStatement(portal.Fields.PreparedStatement) - if ps == nil { - return &pg_error.Error{ - Severity: pg_error.Err, - Code: pg_error.ProtocolViolation, - Message: fmt.Sprintf("prepared statement '%s' not found", ps.Fields.PreparedStatement), - } - } - - which, err := w.w.database.GetRouter().InferRole(ps.Fields.Query) - if err != nil { - return err - } - if !w.w.user.Role.CanUse(which) { + if !w.w.user.Role.CanUse(config.SERVERROLE_PRIMARY) { return errors.New("permission denied") } - target := srv.Choose(which) + target := srv.GetPrimary() w.setCurrentBinding(client, target) defer w.unsetCurrentBinding(client, target) if target == nil { diff --git a/lib/metrics/gat.go b/lib/metrics/gat.go index f238e1cec4861602367dabcd299ade021fcca986..223d288d4ba578f40e25c022705e46227b7e5db4 100644 --- a/lib/metrics/gat.go +++ b/lib/metrics/gat.go @@ -49,10 +49,10 @@ func RecordAcceptConnectionStatus(err error) { g.ConnectionCounter.Inc() } -func RecordActiveConnections(count int) { +func RecordActiveConnections(change int) { if !On() { return } g := GatMetrics() - g.ActiveConnections.Set(float64(count)) + g.ActiveConnections.Add(float64(change)) }