diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 8e3426230444c7a0ebe2da69dc3698b4304ef7f0..5906e5fac7425f301540c6c20456923e3a82c8c4 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -31,6 +31,7 @@ type Server struct { conn net.Conn r *bufio.Reader wr *bufio.Writer + recv <-chan protocol.Packet client gat.Client state gat.ConnectionState @@ -88,6 +89,18 @@ func Dial(ctx context.Context, options []protocol.FieldsStartupMessageParameters } s.r = bufio.NewReader(s.conn) s.wr = bufio.NewWriter(s.conn) + recv := make(chan protocol.Packet, 1024) + s.recv = recv + go func() { + for { + p, err := protocol.ReadBackend(s.r) + if err != nil { + _ = s.Close() + break + } + recv <- p + } + }() s.user = *user s.db = shard.Database @@ -244,7 +257,7 @@ func (s *Server) connect(ctx context.Context) error { var sm sasl.StateMachine for { var pkt protocol.Packet - pkt, err = s.readPacket() + pkt, err = s.readPacket(ctx) if err != nil { return err } @@ -359,13 +372,15 @@ func (s *Server) flush() error { } } -func (s *Server) readPacket() (protocol.Packet, error) { - p, err := protocol.ReadBackend(s.r) - if err != nil { - _ = s.Close() +func (s *Server) readPacket(ctx context.Context) (protocol.Packet, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-s.closed: + return nil, net.ErrClosed + case pk := <-s.recv: + return pk, nil } - //log.Printf("in %#v", p) - return p, err } func (s *Server) stabilize() error { @@ -410,7 +425,7 @@ func (s *Server) stabilize() error { for { var pkt protocol.Packet - pkt, err = s.readPacket() + pkt, err = s.readPacket(context.Background()) if err != nil { return err } @@ -448,7 +463,7 @@ func (s *Server) stabilize() error { } } -func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { +func (s *Server) ensurePreparedStatement(ctx context.Context, client gat.Client, name string) error { s.awaitingSync = true // send prepared statement stmt := client.GetPreparedStatement(name) @@ -469,7 +484,7 @@ func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { } // there is a statement bound that needs to be unbound - s.destructPreparedStatement(name) + s.destructPreparedStatement(ctx, name) } } @@ -479,7 +494,7 @@ func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { return s.writePacket(stmt) } -func (s *Server) ensurePortal(client gat.Client, name string) error { +func (s *Server) ensurePortal(ctx context.Context, client gat.Client, name string) error { s.awaitingSync = true portal := client.GetPortal(name) if portal == nil { @@ -490,7 +505,7 @@ func (s *Server) ensurePortal(client gat.Client, name string) error { } } - err := s.ensurePreparedStatement(client, portal.Fields.PreparedStatement) + err := s.ensurePreparedStatement(ctx, client, portal.Fields.PreparedStatement) if err != nil { return err } @@ -507,7 +522,7 @@ func (s *Server) ensurePortal(client gat.Client, name string) error { return s.writePacket(portal) } -func (s *Server) destructPreparedStatement(name string) { +func (s *Server) destructPreparedStatement(ctx context.Context, name string) { if name == "" { return } @@ -519,23 +534,23 @@ func (s *Server) destructPreparedStatement(name string) { _ = s.flush() // await server ready for { - r, _ := s.readPacket() + r, _ := s.readPacket(ctx) if _, ok := r.(*protocol.ReadyForQuery); ok { return } } } -func (s *Server) destructPortal(name string) { +func (s *Server) destructPortal(ctx context.Context, name string) { portal, ok := s.boundPortals[name] if !ok { return } delete(s.boundPortals, name) - s.destructPreparedStatement(portal.Fields.PreparedStatement) + s.destructPreparedStatement(ctx, portal.Fields.PreparedStatement) } -func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { +func (s *Server) handleRecv(ctx context.Context, client gat.Client, packet protocol.Packet) error { switch pkt := packet.(type) { case *protocol.FunctionCall, *protocol.Query: err := s.writePacket(packet) @@ -550,12 +565,12 @@ func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { s.awaitingSync = true switch pkt.Fields.Which { case 'S': // prepared statement - err := s.ensurePreparedStatement(client, pkt.Fields.Name) + err := s.ensurePreparedStatement(ctx, client, pkt.Fields.Name) if err != nil { return err } case 'P': // portal - err := s.ensurePortal(client, pkt.Fields.Name) + err := s.ensurePortal(ctx, client, pkt.Fields.Name) if err != nil { return err } @@ -574,7 +589,7 @@ func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { } case *protocol.Execute: s.awaitingSync = true - err := s.ensurePortal(client, pkt.Fields.Name) + err := s.ensurePortal(ctx, client, pkt.Fields.Name) if err != nil { return err } @@ -601,7 +616,7 @@ func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { func (s *Server) sendAndLink(ctx context.Context, client gat.Client, initial protocol.Packet) error { s.readyForQuery = false - err := s.handleRecv(client, initial) + err := s.handleRecv(ctx, client, initial) if err != nil { return err } @@ -620,7 +635,7 @@ func (s *Server) link(ctx context.Context, client gat.Client) error { } }() for { - pkt, err := s.readPacket() + pkt, err := s.readPacket(ctx) if err != nil { return err } @@ -679,7 +694,7 @@ func (s *Server) link(ctx context.Context, client gat.Client) error { func (s *Server) handleClientPacket(ctx context.Context, client gat.Client) error { select { case pkt := <-client.Recv(): - return s.handleRecv(client, pkt) + return s.handleRecv(ctx, client, pkt) case <-ctx.Done(): return ctx.Err() }