diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 66f5aaa05427958dac98ec6305ae2f1e867b40c4..d398dcc670aaf0eff4fb306bf21036c3e23ac8d6 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -326,6 +326,10 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return true, c.handle_describe(ctx, cast) case *protocol.Execute: return true, c.handle_execute(ctx, cast) + case *protocol.Sync: + pkt := new(protocol.ReadyForQuery) + pkt.Fields.Status = 'I' + return true, c.Send(pkt) case *protocol.Query: return true, c.handle_query(ctx, cast) case *protocol.FunctionCall: @@ -333,6 +337,7 @@ func (c *Client) tick(ctx context.Context) (bool, error) { case *protocol.Terminate: return false, nil default: + log.Printf("unhandled packet %#v", rsp) } return true, nil } @@ -421,13 +426,11 @@ func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { } func (c *Client) handle_simple_query(ctx context.Context, q string) error { - log.Println("query", q) - //log.Println("query: ", q.Fields.Query) + //log.Println("query:", q) return c.server.SimpleQuery(ctx, c, q) } func (c *Client) handle_transaction(ctx context.Context, q string) error { - log.Println("transaction", q) return c.server.Transaction(ctx, c, q) } @@ -439,6 +442,14 @@ func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) return err } +func (c *Client) GetPreparedStatement(name string) *protocol.Parse { + return c.statements[name] +} + +func (c *Client) GetPortal(name string) *protocol.Bind { + return c.portals[name] +} + func (c *Client) Send(pkt protocol.Packet) error { //log.Printf("sent packet(%s) %+v", reflect.TypeOf(pkt), pkt) _, err := pkt.Write(c.bufwr) diff --git a/lib/gat/gatling/conn_pool/server/server.go b/lib/gat/gatling/conn_pool/server/server.go index 90249bc29ff408699f48e1be38cc4d2d3aba23cd..cca8392a75d41fe9641d7f3065ea82d81374cb9f 100644 --- a/lib/gat/gatling/conn_pool/server/server.go +++ b/lib/gat/gatling/conn_pool/server/server.go @@ -121,7 +121,7 @@ func (s *Server) startup(ctx context.Context) error { }, {}, } - _, err := start.Write(s.wr) + err := s.writePacket(start) if err != nil { return err } @@ -171,7 +171,7 @@ func (s *Server) connect(ctx context.Context) error { _, _ = protocol.WriteInt32(buf, int32(len(bts))) buf.Write(bts) rsp.Fields.Data = buf.Bytes() - _, err = rsp.Write(s.wr) + err = s.writePacket(rsp) }() if err != nil { return err @@ -183,7 +183,7 @@ func (s *Server) connect(ctx context.Context) error { rsp := new(protocol.AuthenticationResponse) rsp.Fields.Data = bts - _, err = rsp.Write(s.wr) + err = s.writePacket(rsp) if err != nil { return err } @@ -248,6 +248,160 @@ func (s *Server) writePacket(pkt protocol.Packet) error { return s.bufwr.Flush() } +func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { + // send prepared statement + stmt := client.GetPreparedStatement(name) + if stmt == nil { + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("prepared statement '%s' does not exist", name), + } + } + + // send prepared statement to server + err := s.writePacket(stmt) + if err != nil { + return err + } + + /*log.Println("wait for server to accept prepared statement") + // make sure server accepted it + var rsp protocol.Packet + rsp, err = protocol.ReadBackend(s.r) + if err != nil { + return err + } + log.Println("received from server", rsp) + if _, ok := rsp.(*protocol.ParseComplete); !ok { + return fmt.Errorf("backend failed to parse prepared statement: %+v", rsp) + }*/ + + return nil +} + +func (s *Server) ensurePortal(client gat.Client, name string) error { + portal := client.GetPortal(name) + if portal == nil { + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("portal '%s' does not exist", name), + } + } + + err := s.ensurePreparedStatement(client, portal.Fields.PreparedStatement) + if err != nil { + return err + } + + err = s.writePacket(portal) + if err != nil { + return err + } + + /*var rsp protocol.Packet + rsp, err = protocol.ReadBackend(s.r) + if err != nil { + return err + } + if _, ok := rsp.(*protocol.BindComplete); !ok { + return fmt.Errorf("backend failed to bind portal: %+v", rsp) + }*/ + + return nil +} + +func (s *Server) destructPreparedStatement(client gat.Client, name string) { + query := new(protocol.Query) + query.Fields.Query = fmt.Sprintf("DEALLOCATE \"%s\"", name) + _ = s.writePacket(query) + // await server ready + for { + r, _ := protocol.ReadBackend(s.r) + if _, ok := r.(*protocol.ReadyForQuery); ok { + return + } + } +} + +func (s *Server) destructPortal(client gat.Client, name string) { + portal := client.GetPortal(name) + s.destructPreparedStatement(client, portal.Fields.PreparedStatement) +} + +func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { + // TODO for now, we're actually just going to send the query and it's binding + // TODO(Garet) keep track of which connections have which prepared statements and portals + switch d.Fields.Which { + case 'S': // prepared statement + err := s.ensurePreparedStatement(client, d.Fields.Name) + if err != nil { + return err + } + defer s.destructPreparedStatement(client, d.Fields.Name) + case 'P': // portal + err := s.ensurePortal(client, d.Fields.Name) + if err != nil { + return err + } + defer s.destructPortal(client, d.Fields.Name) + default: + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", d.Fields.Which), + } + } + + // now we actually execute the thing the client wants + err := s.writePacket(d) + if err != nil { + return err + } + err = s.writePacket(new(protocol.Sync)) + if err != nil { + return err + } + + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + switch pkt.(type) { + case *protocol.BindComplete, *protocol.ParseComplete: + return false, false + case *protocol.ReadyForQuery: + return false, true + default: + return true, false + } + }) +} + +func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { + err := s.ensurePortal(client, e.Fields.Name) + if err != nil { + return err + } + defer s.destructPortal(client, e.Fields.Name) + + err = s.writePacket(e) + if err != nil { + return err + } + err = s.writePacket(new(protocol.Sync)) + if err != nil { + return err + } + + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + switch pkt.(type) { + case *protocol.ReadyForQuery: + return false, true + default: + return true, false + } + }) +} + func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query string) error { // send to server q := new(protocol.Query) @@ -316,8 +470,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin switch r.(type) { case *protocol.Query: //forward to server - _, _ = r.Write(s.bufwr) - err = s.bufwr.Flush() + _ = s.writePacket(r) default: err = fmt.Errorf("expected an error in transaction state but got something else") } @@ -329,8 +482,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin if err != nil { end := new(protocol.Query) end.Fields.Query = "END;" - _, _ = end.Write(s.bufwr) - _ = s.bufwr.Flush() + _ = s.writePacket(end) } } return p.Fields.Status == 'I', p.Fields.Status == 'I' @@ -364,8 +516,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { select { case pkt = <-client.Recv(): case <-cctx.Done(): - _, _ = new(protocol.CopyFail).Write(s.bufwr) - _ = s.bufwr.Flush() + _ = s.writePacket(new(protocol.CopyFail)) rfq := new(protocol.ReadyForQuery) rfq.Fields.Status = 'I' return client.Send(rfq) @@ -386,7 +537,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { } func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) error { - _, err := payload.Write(s.wr) + err := s.writePacket(payload) if err != nil { return err } diff --git a/lib/gat/gatling/conn_pool/worker.go b/lib/gat/gatling/conn_pool/worker.go index 2254db7013f06922aae83d72af397fca08d91e20..f8e45bea30173646356c02a3d10d2e94346a5b6b 100644 --- a/lib/gat/gatling/conn_pool/worker.go +++ b/lib/gat/gatling/conn_pool/worker.go @@ -3,6 +3,7 @@ package conn_pool import ( "context" "fmt" + "gfx.cafe/gfx/pggat/lib/config" "log" "gfx.cafe/gfx/pggat/lib/gat" @@ -111,12 +112,10 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri errch := make(chan error) go func() { defer close(errch) - //log.Println("performing transaction...") select { case errch <- w.z_actually_do_transaction(ctx, c, query): case <-ctx.Done(): } - //log.Println("done", err) }() // wait until query or close @@ -129,10 +128,36 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri } func (w *worker) z_actually_do_describe(ctx context.Context, client gat.Client, payload *protocol.Describe) error { - return nil + c := w.w + srv := c.chooseConnections() + if srv == nil { + return fmt.Errorf("describe('%+v') fail: no server", payload) + } + defer srv.mu.Unlock() + // describe the portal + // we can use a replica because we are just describing what this query will return, query content doesn't matter + // because nothing is actually executed yet + target := srv.choose(config.SERVERROLE_REPLICA) + if target == nil { + return fmt.Errorf("describe('%+v') fail: no server", payload) + } + return target.Describe(client, payload) } func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, payload *protocol.Execute) error { - return nil + c := w.w + srv := c.chooseConnections() + if srv == nil { + return fmt.Errorf("describe('%+v') fail: no server", payload) + } + defer srv.mu.Unlock() + // execute the query + // for now, use primary + // TODO read the query of the underlying prepared statement and choose server accordingly + target := srv.primary + if target == nil { + return fmt.Errorf("describe('%+v') fail: no server", payload) + } + return target.Execute(client, payload) } func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { c := w.w diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index 9e8d4e187a03dd0dc01714f5cf7cf14b79a7e33a..a2e2795396cc60cb8acb39a72bf59f37a47f9824 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -123,7 +123,7 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { cl := client.NewClient(g, g.c, c, false) err := cl.Accept(ctx) if err != nil { - log.Println(err.Error()) + log.Println("err in connection:", err.Error()) switch e := err.(type) { case *pg_error.Error: _ = cl.Send(e.Packet()) diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index 8f2e2249018e5d715db8db57fbb3affe3f5608f9..5a139244e19370bb777421fff84bead959d95d4a 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -8,6 +8,9 @@ import ( ) type Client interface { + GetPreparedStatement(name string) *protocol.Parse + GetPortal(name string) *protocol.Bind + Send(pkt protocol.Packet) error Recv() <-chan protocol.Packet }