diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index d6d9163c16ffaa7d85d9de94e0866750e936adc0..74ab34185023fba45eb733de824a9395b7aa76a7 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -209,6 +209,10 @@ func (c *Client) SetCurrentConn(conn gat.Connection) { } func (c *Client) Accept(ctx context.Context) error { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + // read a packet startup := new(protocol.StartupMessage) err := startup.Read(c.r) @@ -404,7 +408,7 @@ func (c *Client) Accept(ctx context.Context) error { if err != nil { return err } - go c.recvLoop() + go c.recvLoop(cancel) open := true for open { err = c.Flush() @@ -437,7 +441,8 @@ func (c *Client) Accept(ctx context.Context) error { return nil } -func (c *Client) recvLoop() { +func (c *Client) recvLoop(cancel context.CancelFunc) { + defer cancel() for { recv, err := protocol.ReadFrontend(c.r) if err != nil { diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index 746614d0de4ffe95ed4ac9616730870f73d6cea3..6e1811e6748cb1997562300325127d5ca230078e 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -7,6 +7,7 @@ import ( "gfx.cafe/gfx/pggat/lib/gat/admin" "gfx.cafe/gfx/pggat/lib/gat/database" "gfx.cafe/gfx/pggat/lib/gat/gatling/server" + "gfx.cafe/util/go/generic" "io" "net" "sync" @@ -29,14 +30,13 @@ type Gatling struct { chConfig chan *config.Global pools map[string]gat.Database - clients map[gat.ClientID]gat.Client + clients generic.Map[gat.ClientID, gat.Client] } func NewGatling(conf *config.Global) *Gatling { g := &Gatling{ chConfig: make(chan *config.Global, 1), pools: make(map[string]gat.Database), - clients: make(map[gat.ClientID]gat.Client), } // add admin pool adminPool := admin.New(g) @@ -86,24 +86,18 @@ func (g *Gatling) GetDatabases() map[string]gat.Database { } func (g *Gatling) GetClient(id gat.ClientID) gat.Client { - g.mu.RLock() - defer g.mu.RUnlock() - c, ok := g.clients[id] + c, ok := g.clients.Load(id) if !ok { return nil } return c } -func (g *Gatling) GetClients() []gat.Client { - g.mu.RLock() - defer g.mu.RUnlock() - out := make([]gat.Client, len(g.clients)) - idx := 0 - for _, p := range g.clients { - out[idx] = p - idx += 1 - } +func (g *Gatling) GetClients() (out []gat.Client) { + g.clients.Range(func(id gat.ClientID, client gat.Client) bool { + out = append(out, client) + return true + }) return out } @@ -161,12 +155,7 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { errch <- err } close(errch) - err = g.handleConnection(ctx, c) - if err != nil { - if err != io.EOF { - log.Println("disconnected:", err) - } - } + g.handleConnection(ctx, c) }() err = <-errch @@ -179,18 +168,12 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { } // TODO: TLS -func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { +func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) { cl := client.NewClient(g, g.c, c, false) - func() { - g.mu.Lock() - defer g.mu.Unlock() - g.clients[cl.GetId()] = cl - }() + g.clients.Store(cl.GetId(), cl) defer func() { - g.mu.Lock() - defer g.mu.Unlock() - delete(g.clients, cl.GetId()) + g.clients.Delete(cl.GetId()) }() err := cl.Accept(ctx) @@ -202,7 +185,6 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { } } _ = c.Close() - return nil } var _ gat.Gat = (*Gatling)(nil) diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 7ef62ee337c6bf267256cd4683c927c3b1789dcc..bd4720b4ea3e85c7ae991d2726c2e40daa94c4b3 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "errors" "fmt" "net" "reflect" @@ -52,8 +53,10 @@ type Server struct { dbpass string user config.User - healthy bool - awaitingSync bool + healthy bool + awaitingSync bool + readyForQuery bool + copying bool log zlog.Logger @@ -155,39 +158,8 @@ func (s *Server) failHealthCheck(err error) { } func (s *Server) healthCheck() { - check := new(protocol.Query) - check.Fields.Query = "select 1" - err := s.writePacket(check) - if err != nil { - s.failHealthCheck(err) - return - } - err = s.flush() - if err != nil { - s.failHealthCheck(err) - return - } - - // read until we get a ready for query - for { - var recv protocol.Packet - recv, err = s.readPacket() - if err != nil { - s.failHealthCheck(err) - return - } - - switch r := recv.(type) { - case *protocol.ReadyForQuery: - if r.Fields.Status != 'I' { - s.failHealthCheck(fmt.Errorf("expected server to be in command mode but it isn't")) - } - return - case *protocol.DataRow, *protocol.RowDescription, *protocol.CommandComplete: - default: - s.failHealthCheck(fmt.Errorf("expected a Simple Query packet but server sent %#v", recv)) - return - } + if !s.readyForQuery { + s.failHealthCheck(errors.New("expected server to be ready for query")) } } @@ -346,31 +318,12 @@ func (s *Server) connect(ctx context.Context) error { s.lastActivity = time.Now() s.connectedAt = time.Now().UTC() s.state = "idle" + s.readyForQuery = true return nil } } } -func (s *Server) forwardTo(client gat.Client, predicate func(pkt protocol.Packet) (forward bool, finish bool, err error)) error { - var e error - for { - var rsp protocol.Packet - rsp, err := s.readPacket() - if err != nil { - return err - } - //log.Printf("backend packet(%s) %+v", reflect.TypeOf(rsp), rsp) - var forward, finish bool - forward, finish, e = predicate(rsp) - if forward && e == nil { - e = client.Send(rsp) - } - if finish { - return e - } - } -} - func (s *Server) writePacket(pkt protocol.Packet) error { //log.Printf("out %#v", pkt) _, err := pkt.Write(s.wr) @@ -388,10 +341,78 @@ func (s *Server) readPacket() (protocol.Packet, error) { } func (s *Server) stabilize() { - // TODO actually stabilize connection + if s.readyForQuery { + return + } + //log.Println("connection is unstable, attempting to restabilize it") + if s.copying { + //log.Println("failing copy") + s.copying = false + err := s.writePacket(new(protocol.CopyFail)) + if err != nil { + return + } + } if s.awaitingSync { - _ = s.writePacket(new(protocol.Sync)) - _ = s.flush() + //log.Println("syncing") + s.awaitingSync = false + err := s.writePacket(new(protocol.Sync)) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + } + query := new(protocol.Query) + query.Fields.Query = "end" + err := s.writePacket(query) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + + for { + var pkt protocol.Packet + pkt, err = s.readPacket() + if err != nil { + return + } + + //log.Printf("received %+v", pkt) + + switch pk := pkt.(type) { + case *protocol.ReadyForQuery: + if pk.Fields.Status == 'I' { + s.readyForQuery = true + return + } else { + query := new(protocol.Query) + query.Fields.Query = "end" + err = s.writePacket(query) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + } + case *protocol.CopyInResponse, *protocol.CopyBothResponse: + fail := new(protocol.CopyFail) + err = s.writePacket(fail) + if err != nil { + return + } + err = s.flush() + if err != nil { + return + } + } } } @@ -479,10 +500,6 @@ func (s *Server) destructPortal(name string) { s.destructPreparedStatement(portal.Fields.PreparedStatement) } -func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { - return s.sendAndLink(ctx, client, d) -} - func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { switch pkt := packet.(type) { case *protocol.FunctionCall, *protocol.Query: @@ -546,6 +563,7 @@ func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { } func (s *Server) sendAndLink(ctx context.Context, client gat.Client, initial protocol.Packet) error { + s.readyForQuery = false err := s.handleRecv(client, initial) if err != nil { return err @@ -571,6 +589,7 @@ func (s *Server) link(ctx context.Context, client gat.Client) error { case *protocol.ReadyForQuery: if p.Fields.Status == 'I' { // this client is done + s.readyForQuery = true return nil } @@ -591,7 +610,7 @@ func (s *Server) link(ctx context.Context, client gat.Client) error { if err != nil { return err } - case *protocol.CopyInResponse: + case *protocol.CopyInResponse, *protocol.CopyBothResponse: err = client.Send(p) if err != nil { return err @@ -632,6 +651,10 @@ func (s *Server) awaitSync(ctx context.Context, client gat.Client) error { return nil } +func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { + return s.sendAndLink(ctx, client, d) +} + func (s *Server) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { return s.sendAndLink(ctx, client, e) } @@ -650,6 +673,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin } func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { + s.copying = true err := client.Flush() if err != nil { return err @@ -660,8 +684,6 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { select { case pkt = <-client.Recv(): case <-ctx.Done(): - _ = s.writePacket(new(protocol.CopyFail)) - _ = s.flush() return ctx.Err() } err = s.writePacket(pkt) @@ -671,6 +693,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { switch pkt.(type) { case *protocol.CopyDone, *protocol.CopyFail: + s.copying = false // don't error on copyfail because the client is the one that errored, it already knows return s.flush() }