diff --git a/lib/gat/admin/admin.go b/lib/gat/admin/admin.go index 4becc7f97f18931835d2fa033c08ed58f12a9e3c..8780ac73eb4304589d9799b00afcfc3905b90a8c 100644 --- a/lib/gat/admin/admin.go +++ b/lib/gat/admin/admin.go @@ -533,6 +533,10 @@ func (p *Database) GetRouter() gat.QueryRouter { return nil } +func (p *Database) GetName() string { + return "pggat" +} + func (p *Database) WithUser(name string) gat.Pool { conf := p.gat.GetConfig() if name != conf.General.AdminUsername { diff --git a/lib/gat/database/database.go b/lib/gat/database/database.go index ec044f0c17a3d86ee9975a10ea28e21309e839fc..a54060b6a885b77f1736a728787e6a30d3bf016a 100644 --- a/lib/gat/database/database.go +++ b/lib/gat/database/database.go @@ -14,6 +14,7 @@ type Database struct { c *config.Pool users map[string]config.User connPools map[string]gat.Pool + name string stats *gat.PoolStats @@ -24,11 +25,12 @@ type Database struct { mu sync.RWMutex } -func New(dialer gat.Dialer, conf *config.Pool) *Database { +func New(dialer gat.Dialer, name string, conf *config.Pool) *Database { pool := &Database{ connPools: make(map[string]gat.Pool), stats: gat.NewPoolStats(), router: query_router.DefaultRouter(conf), + name: name, dialer: dialer, } @@ -74,6 +76,10 @@ func (p *Database) GetRouter() gat.QueryRouter { return p.router } +func (p *Database) GetName() string { + return p.name +} + func (p *Database) WithUser(name string) gat.Pool { p.mu.RLock() defer p.mu.RUnlock() diff --git a/lib/gat/database/query_router/query_router_test.go b/lib/gat/database/query_router/query_router_test.go index a68247b611d67af69b09168419f96e67dd8e9097..f94af3a9323742a9f75d0701cf7fc4520dcf9c6b 100644 --- a/lib/gat/database/query_router/query_router_test.go +++ b/lib/gat/database/query_router/query_router_test.go @@ -8,7 +8,7 @@ import ( // TODO: adapt tests func TestQueryRouterInterRoleReplica(t *testing.T) { - qr := DefaultRouter + qr := DefaultRouter(nil) role, err := qr.InferRole(`UPDATE items SET name = 'pumpkin' WHERE id = 5`) if err != nil { t.Fatal(err) diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 393ccb8deb84ac860b9286f0eeb0d113a110b14d..c86501514111709645d98bf64cb631450b4f4e8d 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -94,7 +94,6 @@ type Client struct { statements map[string]*protocol.Parse portals map[string]*protocol.Bind conf *config.Global - status rune parser *pg3p.Parser @@ -183,7 +182,6 @@ func NewClient( gatling: gatling, statements: make(map[string]*protocol.Parse), portals: make(map[string]*protocol.Bind), - status: 'I', conf: conf, parser: pg3p.NewParser(), } @@ -427,13 +425,11 @@ func (c *Client) Accept(ctx context.Context) error { return err } } - if c.status == 'I' { - rq := new(protocol.ReadyForQuery) - rq.Fields.Status = 'I' - err = c.Send(rq) - if err != nil { - return err - } + rq := new(protocol.ReadyForQuery) + rq.Fields.Status = 'I' + err = c.Send(rq) + if err != nil { + return err } } return nil @@ -448,8 +444,23 @@ func (c *Client) recvLoop() { } break } - log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) - c.recv <- recv + //log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) + switch pkt := recv.(type) { + case *protocol.Parse: + c.statements[pkt.Fields.PreparedStatement] = pkt + err = c.Send(new(protocol.ParseComplete)) + if err != nil { + break + } + case *protocol.Bind: + c.portals[pkt.Fields.Destination] = pkt + err = c.Send(new(protocol.BindComplete)) + if err != nil { + break + } + default: + c.recv <- recv + } } } @@ -477,16 +488,11 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return false, ctx.Err() } switch cast := rsp.(type) { - case *protocol.Parse: - return true, c.parse(ctx, cast) - case *protocol.Bind: - return true, c.bind(ctx, cast) case *protocol.Describe: return true, c.handle_describe(ctx, cast) case *protocol.Execute: return true, c.handle_execute(ctx, cast) case *protocol.Sync: - c.status = 'I' return true, nil case *protocol.Query: return true, c.handle_query(ctx, cast) @@ -500,28 +506,14 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return true, nil } -func (c *Client) parse(ctx context.Context, q *protocol.Parse) error { - c.statements[q.Fields.PreparedStatement] = q - c.status = 'T' - return c.Send(new(protocol.ParseComplete)) -} - -func (c *Client) bind(ctx context.Context, b *protocol.Bind) error { - c.portals[b.Fields.Destination] = b - c.status = 'T' - return c.Send(new(protocol.BindComplete)) -} - func (c *Client) handle_describe(ctx context.Context, d *protocol.Describe) error { //log.Println("describe") - c.status = 'T' c.startRequest() return c.server.Describe(ctx, c, d) } func (c *Client) handle_execute(ctx context.Context, e *protocol.Execute) error { //log.Println("execute") - c.status = 'T' c.startRequest() return c.server.Execute(ctx, c, e) } diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index 6ec0dd212a6ac0e35bfba7af6173fc57bd6b8f33..5948f1cb35799ccdef75c7c7e804fcc72bd66000 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -142,7 +142,7 @@ func (g *Gatling) ensurePools(c *config.Global) error { if existing, ok := g.pools[name]; ok { existing.EnsureConfig(p) } else { - g.pools[name] = database.New(server.Dial, p) + g.pools[name] = database.New(server.Dial, name, p) } } return nil diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 667991cd46ea46c14cfe4b1bbf7e975a676fa5ca..7ef62ee337c6bf267256cd4683c927c3b1789dcc 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -2,7 +2,6 @@ package server import ( "bufio" - "errors" "fmt" "net" "reflect" @@ -53,7 +52,8 @@ type Server struct { dbpass string user config.User - healthy bool + healthy bool + awaitingSync bool log zlog.Logger @@ -387,6 +387,14 @@ func (s *Server) readPacket() (protocol.Packet, error) { return p, err } +func (s *Server) stabilize() { + // TODO actually stabilize connection + if s.awaitingSync { + _ = s.writePacket(new(protocol.Sync)) + _ = s.flush() + } +} + func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { // send prepared statement stmt := client.GetPreparedStatement(name) @@ -471,183 +479,174 @@ func (s *Server) destructPortal(name string) { s.destructPreparedStatement(portal.Fields.PreparedStatement) } -func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { - switch d.Fields.Which { - case 'S': // prepared statement - err := s.ensurePreparedStatement(client, d.Fields.Name) +func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { + return s.sendAndLink(ctx, client, d) +} + +func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error { + switch pkt := packet.(type) { + case *protocol.FunctionCall, *protocol.Query: + err := s.writePacket(packet) if err != nil { return err } - case 'P': // portal - err := s.ensurePortal(client, d.Fields.Name) + err = s.flush() if err != nil { return err } - default: - return &pg_error.Error{ - Severity: pg_error.Err, - Code: pg_error.ProtocolViolation, - Message: fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", d.Fields.Which), + case *protocol.Describe: + s.awaitingSync = true + switch pkt.Fields.Which { + case 'S': // prepared statement + err := s.ensurePreparedStatement(client, pkt.Fields.Name) + if err != nil { + return err + } + case 'P': // portal + err := s.ensurePortal(client, pkt.Fields.Name) + if err != nil { + return err + } + default: + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", pkt.Fields.Which), + } } - } - // now we actually execute the thing the client wants - err := s.writePacket(d) - if err != nil { - return err + // now we actually execute the thing the client wants + err := s.writePacket(packet) + if err != nil { + return err + } + case *protocol.Execute: + s.awaitingSync = true + err := s.ensurePortal(client, pkt.Fields.Name) + if err != nil { + return err + } + + err = s.writePacket(pkt) + if err != nil { + return err + } + case *protocol.Sync: + s.awaitingSync = false + err := s.writePacket(packet) + if err != nil { + return err + } + err = s.flush() + if err != nil { + return err + } } - err = s.writePacket(new(protocol.Sync)) + return nil +} + +func (s *Server) sendAndLink(ctx context.Context, client gat.Client, initial protocol.Packet) error { + err := s.handleRecv(client, initial) if err != nil { return err } - err = s.flush() + err = s.awaitSync(ctx, client) if err != nil { return err } + return s.link(ctx, client) +} - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) - switch pkt.(type) { +func (s *Server) link(ctx context.Context, client gat.Client) error { + defer s.stabilize() + for { + pkt, err := s.readPacket() + if err != nil { + return err + } + + switch p := pkt.(type) { case *protocol.BindComplete, *protocol.ParseComplete: + // ignore, it is because we bound stuff case *protocol.ReadyForQuery: - finish = true + if p.Fields.Status == 'I' { + // this client is done + return nil + } + + err = client.Send(p) + if err != nil { + return err + } + err = client.Flush() + if err != nil { + return err + } + + err = s.handleClientPacket(ctx, client) + if err != nil { + return err + } + err = s.awaitSync(ctx, client) + if err != nil { + return err + } + case *protocol.CopyInResponse: + err = client.Send(p) + if err != nil { + return err + } + err = client.Flush() + if err != nil { + return err + } + err = s.CopyIn(ctx, client) + if err != nil { + return err + } default: - forward = true + err = client.Send(p) + if err != nil { + return err + } } - return - }) + } } -func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { - log.Printf("execute `%s`", e.Fields.Name) - err := s.ensurePortal(client, e.Fields.Name) - if err != nil { - return err +func (s *Server) handleClientPacket(ctx context.Context, client gat.Client) error { + select { + case pkt := <-client.Recv(): + return s.handleRecv(client, pkt) + case <-ctx.Done(): + return ctx.Err() } +} - err = s.writePacket(e) - if err != nil { - return err - } - err = s.writePacket(new(protocol.Sync)) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err +func (s *Server) awaitSync(ctx context.Context, client gat.Client) error { + for s.awaitingSync { + err := s.handleClientPacket(ctx, client) + if err != nil { + return err + } } + return nil +} - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) - switch p := pkt.(type) { - case *protocol.BindComplete, *protocol.ParseComplete: - case *protocol.ReadyForQuery: - if p.Fields.Status != 'I' { - err = errors.New("transactions are not allowed in statements") - - end := new(protocol.Query) - end.Fields.Query = "END" - _ = s.writePacket(end) - _ = s.flush() - } else { - finish = true - } - default: - forward = true - } - return - }) +func (s *Server) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { + return s.sendAndLink(ctx, client, e) } func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query string) error { // send to server q := new(protocol.Query) q.Fields.Query = query - err := s.writePacket(q) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err - } - - // this function seems wild but it has to be the way it is so we read the whole response, even if the - // client fails midway - // read responses - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Printf("forwarding pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) - switch pkt.(type) { - case *protocol.ReadyForQuery: - // all ReadyForQuery packets end a simple query, regardless of type - finish = true - case *protocol.CopyInResponse: - _ = client.Send(pkt) - err = s.CopyIn(ctx, client) - default: - forward = true - } - return - }) + return s.sendAndLink(ctx, client, q) } func (s *Server) Transaction(ctx context.Context, client gat.Client, query string) error { q := new(protocol.Query) q.Fields.Query = query - err := s.writePacket(q) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err - } - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - //log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) - switch p := pkt.(type) { - case *protocol.ReadyForQuery: - if p.Fields.Status != 'I' { - // send to client and wait for next query - err = client.Send(pkt) - - if err == nil { - err = client.Flush() - if err == nil { - select { - case r := <-client.Recv(): - //log.Printf("got client pkt pkt(%s): %+v", reflect.TypeOf(r), r) - switch r.(type) { - case *protocol.Query: - //forward to server - _ = s.writePacket(r) - _ = s.flush() - default: - err = fmt.Errorf("expected a query in transaction state but got something else") - } - case <-ctx.Done(): - err = ctx.Err() - } - } - } - - if err != nil { - end := new(protocol.Query) - end.Fields.Query = "END" - _ = s.writePacket(end) - _ = s.flush() - } - } else { - finish = true - } - case *protocol.CopyInResponse: - _ = client.Send(pkt) - err = s.CopyIn(ctx, client) - default: - forward = true - } - return - }) + return s.sendAndLink(ctx, client, q) } func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { @@ -656,19 +655,15 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { return err } for { - // detect a disconneted /hanging client by waiting 30 seoncds, else timeout - // otherwise, just keep reading packets until a done or error is received - cctx, cancel := context.WithTimeout(ctx, 30*time.Second) var pkt protocol.Packet // receive a packet, or done if the ctx gets canceled select { case pkt = <-client.Recv(): - case <-cctx.Done(): + case <-ctx.Done(): _ = s.writePacket(new(protocol.CopyFail)) _ = s.flush() - return cctx.Err() + return ctx.Err() } - cancel() err = s.writePacket(pkt) if err != nil { return err @@ -682,25 +677,8 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { } } -func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) error { - err := s.writePacket(payload) - if err != nil { - return err - } - err = s.flush() - if err != nil { - return err - } - // read responses - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { - switch pkt.(type) { - case *protocol.ReadyForQuery: // status 'I' should only be encountered here - finish = true - default: - forward = true - } - return - }) +func (s *Server) CallFunction(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { + return s.sendAndLink(ctx, client, payload) } func (s *Server) Close(ctx context.Context) error { diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index a8af8c35b327aa95ab24aaefcf762d6555f377ec..6c13c30c8da7a543b6f2c1432cf925232d8981cb 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -63,6 +63,7 @@ type Gat interface { type Database interface { GetUser(name string) *config.User GetRouter() QueryRouter + GetName() string WithUser(name string) Pool GetPools() []Pool @@ -130,9 +131,9 @@ type Connection interface { IsCloseNeeded() bool // actions - Describe(client Client, payload *protocol.Describe) error - Execute(client Client, payload *protocol.Execute) error - CallFunction(client Client, payload *protocol.FunctionCall) error + Describe(ctx context.Context, client Client, payload *protocol.Describe) error + Execute(ctx context.Context, client Client, payload *protocol.Execute) error + CallFunction(ctx context.Context, client Client, payload *protocol.FunctionCall) error SimpleQuery(ctx context.Context, client Client, payload string) error Transaction(ctx context.Context, client Client, payload string) error diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go index 24986a8ab5a0b99791be47358fd319f0e89ea32b..f58a256b3e745cc353957a3cab1e2a3fa1c741f6 100644 --- a/lib/gat/pool/session/pool.go +++ b/lib/gat/pool/session/pool.go @@ -98,7 +98,7 @@ func (p *Pool) Describe(ctx context.Context, client gat.Client, describe *protoc if err != nil { return err } - return c.Describe(client, describe) + return c.Describe(ctx, client, describe) } func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol.Execute) error { @@ -106,7 +106,7 @@ func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol if err != nil { return err } - return c.Execute(client, execute) + return c.Execute(ctx, client, execute) } func (p *Pool) SimpleQuery(ctx context.Context, client gat.Client, query string) error { @@ -130,7 +130,7 @@ func (p *Pool) CallFunction(ctx context.Context, client gat.Client, payload *pro if err != nil { return err } - return c.CallFunction(client, payload) + return c.CallFunction(ctx, client, payload) } var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/pool/transaction/worker.go b/lib/gat/pool/transaction/worker.go index 8ac9b0107ffc7a44478e32cc3128c97c41fd5427..a441568be21e4f2d79c29897090160cd6cf7ac17 100644 --- a/lib/gat/pool/transaction/worker.go +++ b/lib/gat/pool/transaction/worker.go @@ -105,7 +105,9 @@ func (w *worker) HandleDescribe(ctx context.Context, c gat.Client, d *protocol.D defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } errch := make(chan error) @@ -129,7 +131,9 @@ func (w *worker) HandleExecute(ctx context.Context, c gat.Client, e *protocol.Ex defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } errch := make(chan error) @@ -153,7 +157,9 @@ func (w *worker) HandleFunction(ctx context.Context, c gat.Client, fn *protocol. defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } errch := make(chan error) @@ -177,7 +183,9 @@ func (w *worker) HandleSimpleQuery(ctx context.Context, c gat.Client, query stri defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } start := time.Now() @@ -208,7 +216,9 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri defer w.ret() if w.w.user.StatementTimeout != 0 { - ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + var done context.CancelFunc + ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond) + defer done() } start := time.Now() @@ -262,7 +272,7 @@ func (w *worker) z_actually_do_describe(ctx context.Context, client gat.Client, } w.setCurrentBinding(client, target) defer w.unsetCurrentBinding(client, target) - return target.Describe(client, payload) + return target.Describe(ctx, client, payload) } func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, payload *protocol.Execute) error { srv := w.chooseShard(client) @@ -302,7 +312,7 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p if target == nil { return fmt.Errorf("describe('%+v') fail: no server", payload) } - return target.Execute(client, payload) + return target.Execute(ctx, client, payload) } func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { srv := w.chooseShard(client) @@ -319,7 +329,7 @@ func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payloa } w.setCurrentBinding(client, target) defer w.unsetCurrentBinding(client, target) - err := target.CallFunction(client, payload) + err := target.CallFunction(ctx, client, payload) if err != nil { return fmt.Errorf("fn('%+v') fail: %w ", payload, err) }