From 7d0aa7a0d3ce06e6c538ec7a122c8a81e03685e9 Mon Sep 17 00:00:00 2001 From: Garet Halliday <ghalliday@gfxlabs.io> Date: Fri, 30 Sep 2022 12:45:43 -0500 Subject: [PATCH] fix --- .../query_router/query_router_test.go | 2 +- lib/gat/gatling/client/client.go | 52 ++- lib/gat/gatling/server/server.go | 312 ++++++++---------- lib/gat/interfaces.go | 6 +- lib/gat/pool/session/pool.go | 6 +- lib/gat/pool/transaction/worker.go | 26 +- 6 files changed, 192 insertions(+), 212 deletions(-) diff --git a/lib/gat/database/query_router/query_router_test.go b/lib/gat/database/query_router/query_router_test.go index a68247b6..f94af3a9 100644 --- a/lib/gat/database/query_router/query_router_test.go +++ b/lib/gat/database/query_router/query_router_test.go @@ -8,7 +8,7 @@ import ( // TODO: adapt tests func TestQueryRouterInterRoleReplica(t *testing.T) { - qr := DefaultRouter + qr := DefaultRouter(nil) role, err := qr.InferRole(`UPDATE items SET name = 'pumpkin' WHERE id = 5`) if err != nil { t.Fatal(err) diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 20ac2ca5..d6d9163c 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -92,7 +92,6 @@ type Client struct { statements map[string]*protocol.Parse portals map[string]*protocol.Bind conf *config.Global - status rune parser *pg3p.Parser @@ -181,7 +180,6 @@ func NewClient( gatling: gatling, statements: make(map[string]*protocol.Parse), portals: make(map[string]*protocol.Bind), - status: 'I', conf: conf, parser: pg3p.NewParser(), } @@ -429,13 +427,11 @@ func (c *Client) Accept(ctx context.Context) error { return err } } - if c.status == 'I' { - rq := new(protocol.ReadyForQuery) - rq.Fields.Status = 'I' - err = c.Send(rq) - if err != nil { - return err - } + rq := new(protocol.ReadyForQuery) + rq.Fields.Status = 'I' + err = c.Send(rq) + if err != nil { + return err } } return nil @@ -450,8 +446,23 @@ func (c *Client) recvLoop() { } break } - log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) - c.recv <- recv + //log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) + switch pkt := recv.(type) { + case *protocol.Parse: + c.statements[pkt.Fields.PreparedStatement] = pkt + err = c.Send(new(protocol.ParseComplete)) + if err != nil { + break + } + case *protocol.Bind: + c.portals[pkt.Fields.Destination] = pkt + err = c.Send(new(protocol.BindComplete)) + if err != nil { + break + } + default: + c.recv <- recv + } } } @@ -479,16 +490,11 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return false, ctx.Err() } switch cast := rsp.(type) { - case *protocol.Parse: - return true, c.parse(ctx, cast) - case *protocol.Bind: - return true, c.bind(ctx, cast) case *protocol.Describe: return true, c.handle_describe(ctx, cast) case *protocol.Execute: return true, c.handle_execute(ctx, cast) case *protocol.Sync: - c.status = 'I' return true, nil case *protocol.Query: return true, c.handle_query(ctx, cast) @@ -502,28 +508,14 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return true, nil } -func (c *Client) parse(ctx context.Context, q *protocol.Parse) error { - c.statements[q.Fields.PreparedStatement] = q - c.status = 'T' - return c.Send(new(protocol.ParseComplete)) -} - -func (c *Client) bind(ctx context.Context, b *protocol.Bind) error { - c.portals[b.Fields.Destination] = b - c.status = 'T' - return c.Send(new(protocol.BindComplete)) -} - func (c *Client) handle_describe(ctx context.Context, d *protocol.Describe) error { //log.Println("describe") - c.status = 'T' c.startRequest() return c.server.Describe(ctx, c, d) } func (c *Client) handle_execute(ctx context.Context, e *protocol.Execute) error { //log.Println("execute") - c.status = 'T' c.startRequest() return c.server.Execute(ctx, c, e) } diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 667991cd..7ef62ee3 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -2,7 +2,6 @@ package server import ( "bufio" - "errors" "fmt" "net" "reflect" @@ -53,7 +52,8 @@ type Server struct { dbpass string user config.User - healthy bool + healthy bool + awaitingSync bool log zlog.Logger @@ -387,6 +387,14 @@ func (s *Server) readPacket() (protocol.Packet, error) { return p, err } +func (s *Server) stabilize() { + // TODO actually stabilize connection + if s.awaitingSync { + _ = s.writePacket(new(protocol.Sync)) + _ = s.flush() + } +} + func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { // send prepared statement stmt := client.GetPreparedStatement(name) @@ -471,183 +479,174 @@ func (s *Server) destructPortal(name string) { s.destructPreparedStatement(portal.Fields.PreparedStatement) } -func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { - switch d.Fields.Which { - case 'S': // prepared statement - err := s.ensurePreparedStatement(client, d.Fields.Name) +func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { + return s.sendAndLink(ctx, client, d) +} + +func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { + switch pkt := packet.(type) { + case *protocol.FunctionCall, *protocol.Query: + err := s.writePacket(packet) if err != nil { return err } - case 'P': // portal - err := s.ensurePortal(client, d.Fields.Name) + err = s.flush() if err != nil { return err } - default: - return &pg_error.Error{ - Severity: pg_error.Err, - Code: pg_error.ProtocolViolation, - Message: fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", d.Fields.Which), + case *protocol.Describe: + s.awaitingSync = true + switch pkt.Fields.Which { + case 'S': // prepared statement + err := s.ensurePreparedStatement(client, pkt.Fields.Name) + if err != nil { + return err + } + case 'P': // portal + err := s.ensurePortal(client, pkt.Fields.Name) + if err != nil { + return err + } + default: + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", pkt.Fields.Which), + } } - } - // now we actually execute the thing the client wants - err := s.writePacket(d) - if err != nil { - return err + // now we actually execute the thing the client wants + err := s.writePacket(packet) + if err != nil { + return err + } + case *protocol.Execute: + s.awaitingSync = true + err := s.ensurePortal(client, pkt.Fields.Name) + if err != nil { + return err + } + + err = s.writePacket(pkt) + if err != nil { + return err + } + case *protocol.Sync: + s.awaitingSync = false + err := s.writePacket(packet) + if err != nil { + return err + } + err = s.flush() + if err != nil { + return err + } } - err = s.writePacket(new(protocol.Sync)) + return nil +} + +func (s *Server) sendAndLink(ctx context.Context, client gat.Client, initial protocol.Packet) error { + err := s.handleRecv(client, initial) if err != nil { return err } - err = s.flush() + err = s.awaitSync(ctx, client) if err != nil { return err } + return s.link(ctx, client) +} - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) - switch pkt.(type) { +func (s *Server) link(ctx context.Context, client gat.Client) error { + defer s.stabilize() + for { + pkt, err := s.readPacket() + if err != nil { + return err + } + + switch p := pkt.(type) { case *protocol.BindComplete, *protocol.ParseComplete: + // ignore, it is because we bound stuff case *protocol.ReadyForQuery: - finish = true + if p.Fields.Status == 'I' { + // this client is done + return nil + } + + err = client.Send(p) + if err != nil { + return err + } + err = client.Flush() + if err != nil { + return err + } + + err = s.handleClientPacket(ctx, client) + if err != nil { + return err + } + err = s.awaitSync(ctx, client) + if err != nil { + return err + } + case *protocol.CopyInResponse: + err = client.Send(p) + if err != nil { + return err + } + err = client.Flush() + if err != nil { + return err + } + err = s.CopyIn(ctx, client) + if err != nil { + return err + } default: - forward = true + err = client.Send(p) + if err != nil { + return err + } } - return - }) + } } -func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { - log.Printf("execute `%s`", e.Fields.Name) - err := s.ensurePortal(client, e.Fields.Name) - if err != nil { - return err +func (s *Server) handleClientPacket(ctx context.Context, client gat.Client) error { + select { + case pkt := <-client.Recv(): + return s.handleRecv(client, pkt) + case <-ctx.Done(): + return ctx.Err() } +} - err = s.writePacket(e) - if err != nil { - return err - } - err = s.writePacket(new(protocol.Sync)) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err +func (s *Server) awaitSync(ctx context.Context, client gat.Client) error { + for s.awaitingSync { + err := s.handleClientPacket(ctx, client) + if err != nil { + return err + } } + return nil +} - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) - switch p := pkt.(type) { - case *protocol.BindComplete, *protocol.ParseComplete: - case *protocol.ReadyForQuery: - if p.Fields.Status != 'I' { - err = errors.New("transactions are not allowed in statements") - - end := new(protocol.Query) - end.Fields.Query = "END" - _ = s.writePacket(end) - _ = s.flush() - } else { - finish = true - } - default: - forward = true - } - return - }) +func (s *Server) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { + return s.sendAndLink(ctx, client, e) } func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query string) error { // send to server q := new(protocol.Query) q.Fields.Query = query - err := s.writePacket(q) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err - } - - // this function seems wild but it has to be the way it is so we read the whole response, even if the - // client fails midway - // read responses - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Printf("forwarding pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) - switch pkt.(type) { - case *protocol.ReadyForQuery: - // all ReadyForQuery packets end a simple query, regardless of type - finish = true - case *protocol.CopyInResponse: - _ = client.Send(pkt) - err = s.CopyIn(ctx, client) - default: - forward = true - } - return - }) + return s.sendAndLink(ctx, client, q) } func (s *Server) Transaction(ctx context.Context, client gat.Client, query string) error { q := new(protocol.Query) q.Fields.Query = query - err := s.writePacket(q) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err - } - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) - switch p := pkt.(type) { - case *protocol.ReadyForQuery: - if p.Fields.Status != 'I' { - // send to client and wait for next query - err = client.Send(pkt) - - if err == nil { - err = client.Flush() - if err == nil { - select { - case r := <-client.Recv(): - //log.Printf("got client pkt pkt(%s): %+v", reflect.TypeOf(r), r) - switch r.(type) { - case *protocol.Query: - //forward to server - _ = s.writePacket(r) - _ = s.flush() - default: - err = fmt.Errorf("expected a query in transaction state but got something else") - } - case <-ctx.Done(): - err = ctx.Err() - } - } - } - - if err != nil { - end := new(protocol.Query) - end.Fields.Query = "END" - _ = s.writePacket(end) - _ = s.flush() - } - } else { - finish = true - } - case *protocol.CopyInResponse: - _ = client.Send(pkt) - err = s.CopyIn(ctx, client) - default: - forward = true - } - return - }) + return s.sendAndLink(ctx, client, q) } func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { @@ -656,19 +655,15 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { return err } for { - // detect a disconneted /hanging client by waiting 30 seoncds, else timeout - // otherwise, just keep reading packets until a done or error is received - cctx, cancel := context.WithTimeout(ctx, 30*time.Second) var pkt protocol.Packet // receive a packet, or done if the ctx gets canceled select { case pkt = <-client.Recv(): - case <-cctx.Done(): + case <-ctx.Done(): _ = s.writePacket(new(protocol.CopyFail)) _ = s.flush() - return cctx.Err() + return ctx.Err() } - cancel() err = s.writePacket(pkt) if err != nil { return err @@ -682,25 +677,8 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { } } -func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) error { - err := s.writePacket(payload) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err - } - // read responses - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - switch pkt.(type) { - case *protocol.ReadyForQuery: // status 'I' should only be encountered here - finish = true - default: - forward = true - } - return - }) +func (s *Server) CallFunction(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { + return s.sendAndLink(ctx, client, payload) } func (s *Server) Close(ctx context.Context) error { diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index a8af8c35..fc4ab08e 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -130,9 +130,9 @@ type Connection interface { IsCloseNeeded() bool // actions - Describe(client Client, payload *protocol.Describe) error - Execute(client Client, payload *protocol.Execute) error - CallFunction(client Client, payload *protocol.FunctionCall) error + Describe(ctx context.Context, client Client, payload *protocol.Describe) error + Execute(ctx context.Context, client Client, payload *protocol.Execute) error + CallFunction(ctx context.Context, client Client, payload *protocol.FunctionCall) error SimpleQuery(ctx context.Context, client Client, payload string) error Transaction(ctx context.Context, client Client, payload string) error diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go index 24986a8a..f58a256b 100644 --- a/lib/gat/pool/session/pool.go +++ b/lib/gat/pool/session/pool.go @@ -98,7 +98,7 @@ func (p *Pool) Describe(ctx context.Context, client gat.Client, describe *protoc if err != nil { return err } - return c.Describe(client, describe) + return c.Describe(ctx, client, describe) } func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol.Execute) error { @@ -106,7 +106,7 @@ func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol if err != nil { return err } - return c.Execute(client, execute) + return c.Execute(ctx, client, execute) } func (p *Pool) SimpleQuery(ctx context.Context, client gat.Client, query string) error { @@ -130,7 +130,7 @@ func (p *Pool) CallFunction(ctx context.Context, client gat.Client, payload *pro if err != nil { return err } - return c.CallFunction(client, payload) + return c.CallFunction(ctx, client, payload) } var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/pool/transaction/worker.go b/lib/gat/pool/transaction/worker.go index 94a3083a..f1ff43d9 100644 --- a/lib/gat/pool/transaction/worker.go +++ b/lib/gat/pool/transaction/worker.go @@ -104,7 +104,9 @@ func (w *worker) HandleDescribe(ctx context.Context, c gat.Client, d *protocol.D defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } errch := make(chan error) @@ -128,7 +130,9 @@ func (w *worker) HandleExecute(ctx context.Context, c gat.Client, e *protocol.Ex defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } errch := make(chan error) @@ -152,7 +156,9 @@ func (w *worker) HandleFunction(ctx context.Context, c gat.Client, fn *protocol. defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } errch := make(chan error) @@ -176,7 +182,9 @@ func (w *worker) HandleSimpleQuery(ctx context.Context, c gat.Client, query stri defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } start := time.Now() @@ -206,7 +214,9 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } start := time.Now() @@ -259,7 +269,7 @@ func (w *worker) z_actually_do_describe(ctx context.Context, client gat.Client, } w.setCurrentBinding(client, target) defer w.unsetCurrentBinding(client, target) - return target.Describe(client, payload) + return target.Describe(ctx, client, payload) } func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, payload *protocol.Execute) error { srv := w.chooseShard(client) @@ -299,7 +309,7 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p if target == nil { return fmt.Errorf("describe('%+v') fail: no server", payload) } - return target.Execute(client, payload) + return target.Execute(ctx, client, payload) } func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { srv := w.chooseShard(client) @@ -316,7 +326,7 @@ func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payloa } w.setCurrentBinding(client, target) defer w.unsetCurrentBinding(client, target) - err := target.CallFunction(client, payload) + err := target.CallFunction(ctx, client, payload) if err != nil { return fmt.Errorf("fn('%+v') fail: %w ", payload, err) } -- GitLab