diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index b0f1243064e76018f49d2803a5a804e8a61aaa44..33a258e3c479ba5538910d1ae16ac244671f0a4a 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -22,13 +22,14 @@ import ( "net" "reflect" "strings" + "sync" ) // / client state, one per client type Client struct { conn net.Conn r *bufio.Reader - wr io.Writer + wr *bufio.Writer recv chan protocol.Packet @@ -62,6 +63,8 @@ type Client struct { state rune log zlog.Logger + + mu sync.Mutex } func NewClient( @@ -76,7 +79,7 @@ func NewClient( c := &Client{ conn: conn, r: bufio.NewReader(conn), - wr: conn, + wr: bufio.NewWriter(conn), recv: make(chan protocol.Packet), addr: conn.RemoteAddr(), pid: int32(pid.Int64()), @@ -129,6 +132,10 @@ func (c *Client) Accept(ctx context.Context) error { if err != nil { return err } + err = c.wr.Flush() + if err != nil { + return err + } startup = new(protocol.StartupMessage) err = startup.Read(c.r) if err != nil { @@ -139,6 +146,10 @@ func (c *Client) Accept(ctx context.Context) error { if err != nil { return err } + err = c.wr.Flush() + if err != nil { + return err + } //TODO: we need to do an ssl handshake here. var cert tls.Certificate cert, err = tls.LoadX509KeyPair(c.conf.General.TlsCertificate, c.conf.General.TlsPrivateKey) @@ -151,7 +162,7 @@ func (c *Client) Accept(ctx context.Context) error { } c.conn = tls.Server(c.conn, cfg) c.r = bufio.NewReader(c.conn) - c.wr = c.conn + c.wr = bufio.NewWriter(c.conn) err = startup.Read(c.r) if err != nil { return err @@ -200,7 +211,11 @@ func (c *Client) Accept(ctx context.Context) error { if err != nil { return err } - _, err = pkt.Write(c.wr) + err = c.Send(pkt) + if err != nil { + return err + } + err = c.Flush() if err != nil { return err } @@ -263,7 +278,7 @@ func (c *Client) Accept(ctx context.Context) error { authOk := new(protocol.Authentication) authOk.Fields.Code = 0 - _, err = authOk.Write(c.wr) + err = c.Send(authOk) if err != nil { return err } @@ -271,7 +286,7 @@ func (c *Client) Accept(ctx context.Context) error { // info := c.server.GetServerInfo() for _, inf := range info { - _, err = inf.Write(c.wr) + err = c.Send(inf) if err != nil { return err } @@ -279,19 +294,23 @@ func (c *Client) Accept(ctx context.Context) error { backendKeyData := new(protocol.BackendKeyData) backendKeyData.Fields.ProcessID = c.pid backendKeyData.Fields.SecretKey = c.secret_key - _, err = backendKeyData.Write(c.wr) + err = c.Send(backendKeyData) if err != nil { return err } readyForQuery := new(protocol.ReadyForQuery) readyForQuery.Fields.Status = byte('I') - _, err = readyForQuery.Write(c.wr) + err = c.Send(readyForQuery) if err != nil { return err } go c.recvLoop() open := true for open { + err = c.Flush() + if err != nil { + return err + } open, err = c.tick(ctx) if !open { break @@ -348,7 +367,7 @@ func (c *Client) handle_cancel(ctx context.Context, p *protocol.StartupMessage) func (c *Client) tick(ctx context.Context) (bool, error) { var rsp protocol.Packet select { - case rsp = <-c.Recv(): + case rsp = <-c.recv: case <-ctx.Done(): return false, ctx.Err() } @@ -492,11 +511,19 @@ func (c *Client) GetPortal(name string) *protocol.Bind { } func (c *Client) Send(pkt protocol.Packet) error { + c.mu.Lock() + defer c.mu.Unlock() //log.Printf("sent packet(%s) %+v", reflect.TypeOf(pkt), pkt) _, err := pkt.Write(c.wr) return err } +func (c *Client) Flush() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.wr.Flush() +} + func (c *Client) Recv() <-chan protocol.Packet { return c.recv } diff --git a/lib/gat/gatling/conn_pool/server/server.go b/lib/gat/gatling/conn_pool/server/server.go index c30630878f5c905550972adc0af226a48e4449fb..4fa2df21b777d7bad449909ab6395bcefdcb435c 100644 --- a/lib/gat/gatling/conn_pool/server/server.go +++ b/lib/gat/gatling/conn_pool/server/server.go @@ -3,7 +3,6 @@ package server import ( "bufio" "fmt" - "io" "net" "reflect" "time" @@ -28,7 +27,7 @@ type Server struct { remote net.Addr conn net.Conn r *bufio.Reader - wr io.Writer + wr *bufio.Writer server_info []*protocol.ParameterStatus @@ -72,13 +71,9 @@ func Dial(ctx context.Context, if err != nil { return nil, err } - err = s.conn.(*net.TCPConn).SetNoDelay(false) - if err != nil { - return nil, err - } s.remote = s.conn.RemoteAddr() s.r = bufio.NewReader(s.conn) - s.wr = s.conn + s.wr = bufio.NewWriter(s.conn) s.user = *user s.db = db @@ -127,7 +122,7 @@ func (s *Server) startup(ctx context.Context) error { if err != nil { return err } - return nil + return s.flush() } func (s *Server) connect(ctx context.Context) error { @@ -164,17 +159,20 @@ func (s *Server) connect(ctx context.Context) error { if err != nil { return err } - func() { - rsp := new(protocol.AuthenticationResponse) - buf := bufpool.Get(len(scrm.Name()) + 1 + 4 + len(bts)) - buf.Reset() - defer bufpool.Put(buf) - _, _ = protocol.WriteString(buf, scrm.Name()) - _, _ = protocol.WriteInt32(buf, int32(len(bts))) - buf.Write(bts) - rsp.Fields.Data = buf.Bytes() - err = s.writePacket(rsp) - }() + + rsp := new(protocol.AuthenticationResponse) + buf := bufpool.Get(len(scrm.Name()) + 1 + 4 + len(bts)) + buf.Reset() + _, _ = protocol.WriteString(buf, scrm.Name()) + _, _ = protocol.WriteInt32(buf, int32(len(bts))) + buf.Write(bts) + rsp.Fields.Data = buf.Bytes() + err = s.writePacket(rsp) + bufpool.Put(buf) + if err != nil { + return err + } + err = s.flush() if err != nil { return err } @@ -189,6 +187,10 @@ func (s *Server) connect(ctx context.Context) error { if err != nil { return err } + err = s.flush() + if err != nil { + return err + } case 12: // SASL_FINAL s.log.Debug().Str("method", "scram256").Msg("sasl final") var done bool @@ -244,6 +246,10 @@ func (s *Server) writePacket(pkt protocol.Packet) error { return err } +func (s *Server) flush() error { + return s.wr.Flush() +} + func (s *Server) readPacket() (protocol.Packet, error) { return protocol.ReadBackend(s.r) } @@ -311,6 +317,7 @@ func (s *Server) destructPreparedStatement(name string) { query := new(protocol.Query) query.Fields.Query = fmt.Sprintf("DEALLOCATE \"%s\"", name) _ = s.writePacket(query) + _ = s.flush() // await server ready for { r, _ := s.readPacket() @@ -358,6 +365,10 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { if err != nil { return err } + err = s.flush() + if err != nil { + return err + } return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) @@ -386,6 +397,10 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { if err != nil { return err } + err = s.flush() + if err != nil { + return err + } return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) @@ -408,6 +423,10 @@ func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query strin if err != nil { return err } + err = s.flush() + if err != nil { + return err + } // this function seems wild but it has to be the way it is so we read the whole response, even if the // client fails midway @@ -435,6 +454,10 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin if err != nil { return err } + err = s.flush() + if err != nil { + return err + } return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) switch p := pkt.(type) { @@ -445,18 +468,22 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin err = client.Send(pkt) if err == nil { - select { - case r := <-client.Recv(): - //log.Printf("got client pkt pkt(%s): %+v", reflect.TypeOf(r), r) - switch r.(type) { - case *protocol.Query: - //forward to server - _ = s.writePacket(r) - default: - err = fmt.Errorf("expected an error in transaction state but got something else") + err = client.Flush() + if err == nil { + select { + case r := <-client.Recv(): + //log.Printf("got client pkt pkt(%s): %+v", reflect.TypeOf(r), r) + switch r.(type) { + case *protocol.Query: + //forward to server + _ = s.writePacket(r) + _ = s.flush() + default: + err = fmt.Errorf("expected an error in transaction state but got something else") + } + case <-ctx.Done(): + err = ctx.Err() } - case <-ctx.Done(): - err = ctx.Err() } } @@ -464,6 +491,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin end := new(protocol.Query) end.Fields.Query = "END;" _ = s.writePacket(end) + _ = s.flush() } } else { finish = true @@ -479,6 +507,10 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin } func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { + err := client.Flush() + if err != nil { + return err + } for { // detect a disconneted /hanging client by waiting 30 seoncds, else timeout // otherwise, just keep reading packets until a done or error is received @@ -489,10 +521,11 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { case pkt = <-client.Recv(): case <-cctx.Done(): _ = s.writePacket(new(protocol.CopyFail)) + _ = s.flush() return cctx.Err() } cancel() - err := s.writePacket(pkt) + err = s.writePacket(pkt) if err != nil { return err } @@ -500,7 +533,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { switch pkt.(type) { case *protocol.CopyDone, *protocol.CopyFail: // don't error on copyfail because the client is the one that errored, it already knows - return nil + return s.flush() } } } @@ -510,6 +543,10 @@ func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) if err != nil { return err } + err = s.flush() + if err != nil { + return err + } // read responses return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { switch pkt.(type) { diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index d969933937ade0a80d7e744979ab38cca28feca3..504055bbc75d31e77132c7cb23ee3fe3cecf9993 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -134,11 +134,6 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { // TODO: TLS func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { - err := c.(*net.TCPConn).SetNoDelay(false) - if err != nil { - return err - } - cl := client.NewClient(g, g.c, c, false) func() { @@ -152,10 +147,11 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { delete(g.clients, cl.Id()) }() - err = cl.Accept(ctx) + err := cl.Accept(ctx) if err != nil { log.Println("err in connection:", err.Error()) _ = cl.Send(pg_error.IntoPacket(err)) + _ = cl.Flush() } _ = c.Close() return nil diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index 2a62ab85de47287deef9451f242eb70e149b0750..86cb3bfacf68669f038e43905c3e848ec82506f2 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -19,6 +19,7 @@ type Client interface { SetCurrentConn(conn Connection) Send(pkt protocol.Packet) error + Flush() error Recv() <-chan protocol.Packet }