diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index bb99ea6c7a868b8e18a0319201605f26b668d06f..e8447eb6ce18d6fb5b86e46c499a70c8b96f5e90 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -2,11 +2,9 @@ package main import ( "context" - "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat/gatling" "git.tuxpa.in/a/zlog/log" - "net/http" _ "net/http/pprof" ) @@ -17,6 +15,7 @@ const CONFIG = "./config_data.yml" func main() { //zlog.SetGlobalLevel(zlog.PanicLevel) go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) }() diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index d398dcc670aaf0eff4fb306bf21036c3e23ac8d6..d7289af7649ab74b433e6a6f6c783dae13cd99c0 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -60,6 +60,7 @@ type Client struct { statements map[string]*protocol.Parse portals map[string]*protocol.Bind conf *config.Global + state rune log zlog.Logger } @@ -80,6 +81,7 @@ func NewClient( gatling: gatling, statements: make(map[string]*protocol.Parse), portals: make(map[string]*protocol.Bind), + state: 'I', conf: conf, } c.log = log.With(). @@ -282,8 +284,22 @@ func (c *Client) Accept(ctx context.Context) error { open := true for open { open, err = c.tick(ctx) + if !open { + break + } if err != nil { - return err + err = c.Send(pg_error.IntoPacket(err)) + if err != nil { + return err + } + } + if c.state == 'I' { + rq := new(protocol.ReadyForQuery) + rq.Fields.Status = 'I' + err = c.Send(rq) + if err != nil { + return err + } } } return nil @@ -327,9 +343,8 @@ func (c *Client) tick(ctx context.Context) (bool, error) { case *protocol.Execute: return true, c.handle_execute(ctx, cast) case *protocol.Sync: - pkt := new(protocol.ReadyForQuery) - pkt.Fields.Status = 'I' - return true, c.Send(pkt) + c.state = 'I' + return true, nil case *protocol.Query: return true, c.handle_query(ctx, cast) case *protocol.FunctionCall: @@ -344,19 +359,25 @@ func (c *Client) tick(ctx context.Context) (bool, error) { func (c *Client) parse(ctx context.Context, q *protocol.Parse) error { c.statements[q.Fields.PreparedStatement] = q + c.state = 'T' return c.Send(new(protocol.ParseComplete)) } func (c *Client) bind(ctx context.Context, b *protocol.Bind) error { c.portals[b.Fields.Destination] = b + c.state = 'T' return c.Send(new(protocol.BindComplete)) } func (c *Client) handle_describe(ctx context.Context, d *protocol.Describe) error { + //log.Println("describe") + c.state = 'T' return c.server.Describe(ctx, c, d) } func (c *Client) handle_execute(ctx context.Context, e *protocol.Execute) error { + //log.Println("execute") + c.state = 'T' return c.server.Execute(ctx, c, e) } @@ -431,6 +452,7 @@ func (c *Client) handle_simple_query(ctx context.Context, q string) error { } func (c *Client) handle_transaction(ctx context.Context, q string) error { + //log.Println("transaction:", q) return c.server.Transaction(ctx, c, q) } diff --git a/lib/gat/gatling/conn_pool/server/server.go b/lib/gat/gatling/conn_pool/server/server.go index 8dc7653692d464ca0d1676f4495ed26d5f5ac74c..747561747e0dbec12f4ab68595baa3be790f2dd3 100644 --- a/lib/gat/gatling/conn_pool/server/server.go +++ b/lib/gat/gatling/conn_pool/server/server.go @@ -6,7 +6,6 @@ import ( "io" "net" "reflect" - "strings" "time" "gfx.cafe/gfx/pggat/lib/gat" @@ -139,7 +138,7 @@ func (s *Server) connect(ctx context.Context) error { var sm sasl.StateMachine for { var pkt protocol.Packet - pkt, err = protocol.ReadBackend(s.r) + pkt, err = s.readPacket() if err != nil { return err } @@ -218,28 +217,34 @@ func (s *Server) connect(ctx context.Context) error { } } -func (s *Server) forwardTo(client gat.Client, predicate func(pkt protocol.Packet) (forward bool, finish bool)) error { +func (s *Server) forwardTo(client gat.Client, predicate func(pkt protocol.Packet) (forward bool, finish bool, err error)) error { + var e error for { var rsp protocol.Packet - rsp, err := protocol.ReadBackend(s.r) + rsp, err := s.readPacket() if err != nil { return err } - forward, finish := predicate(rsp) - if forward { - err = client.Send(rsp) - if err != nil { - return err - } + //log.Printf("backend packet(%s) %+v", reflect.TypeOf(rsp), rsp) + var forward, finish bool + forward, finish, e = predicate(rsp) + if forward && e == nil { + e = client.Send(rsp) } if finish { - return nil + return e } } } -func (s *Server) writePacket(pkt protocol.Packet) error { +func (s *Server) writeNoFlush(pkt protocol.Packet) error { + //log.Printf("send backend packet(%s) %+v", reflect.TypeOf(pkt), pkt) _, err := pkt.Write(s.bufwr) + return err +} + +func (s *Server) writePacket(pkt protocol.Packet) error { + err := s.writeNoFlush(pkt) if err != nil { s.bufwr.Reset(s.wr) return err @@ -247,6 +252,10 @@ func (s *Server) writePacket(pkt protocol.Packet) error { return s.bufwr.Flush() } +func (s *Server) readPacket() (protocol.Packet, error) { + return protocol.ReadBackend(s.r) +} + func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { // send prepared statement stmt := client.GetPreparedStatement(name) @@ -272,7 +281,7 @@ func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { s.bound_prepared_statments[name] = stmt // send prepared statement to server - _, _ = stmt.Write(s.bufwr) + _ = s.writeNoFlush(stmt) return nil } @@ -292,26 +301,31 @@ func (s *Server) ensurePortal(client gat.Client, name string) error { return err } - if prev, ok := s.bound_portals[name]; ok { - if reflect.DeepEqual(prev, portal) { - return nil + if name != "" { + if prev, ok := s.bound_portals[name]; ok { + if reflect.DeepEqual(prev, portal) { + return nil + } } } s.bound_portals[name] = portal - _, _ = portal.Write(s.bufwr) + _ = s.writeNoFlush(portal) return nil } func (s *Server) destructPreparedStatement(name string) { + if name == "" { + return + } delete(s.bound_prepared_statments, name) query := new(protocol.Query) query.Fields.Query = fmt.Sprintf("DEALLOCATE \"%s\"", name) _ = s.writePacket(query) // await server ready for { - r, _ := protocol.ReadBackend(s.r) + r, _ := s.readPacket() if _, ok := r.(*protocol.ReadyForQuery); ok { return } @@ -348,21 +362,22 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { } // now we actually execute the thing the client wants - _, _ = d.Write(s.bufwr) + _ = s.writeNoFlush(d) err := s.writePacket(new(protocol.Sync)) if err != nil { return err } - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { + //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) switch pkt.(type) { case *protocol.BindComplete, *protocol.ParseComplete: - return false, false case *protocol.ReadyForQuery: - return false, true + finish = true default: - return true, false + forward = true } + return }) } @@ -372,19 +387,22 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { return err } - _, _ = e.Write(s.bufwr) + _ = s.writeNoFlush(e) err = s.writePacket(new(protocol.Sync)) if err != nil { return err } - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { + //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) switch pkt.(type) { + case *protocol.BindComplete, *protocol.ParseComplete: case *protocol.ReadyForQuery: - return false, true + finish = true default: - return true, false + forward = true } + return }) } @@ -396,39 +414,24 @@ func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query strin if err != nil { return err } - if strings.Contains(query, "pg_sleep") { - go func() { - time.Sleep(1 * time.Second) - log.Println("cancel: ", s.Cancel()) - }() - } + // this function seems wild but it has to be the way it is so we read the whole response, even if the // client fails midway // read responses - e := s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Printf("forwarding pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) switch pkt.(type) { case *protocol.ReadyForQuery: // all ReadyForQuery packets end a simple query, regardless of type - return err == nil, true + finish = true case *protocol.CopyInResponse: - err = client.Send(pkt) - if err != nil { - return false, false - } + _ = client.Send(pkt) err = s.CopyIn(ctx, client) - if err != nil { - return false, false - } - return false, false default: - return err == nil, false + forward = true } + return }) - if e != nil { - return e - } - return err } func (s *Server) Transaction(ctx context.Context, client gat.Client, query string) error { @@ -438,16 +441,14 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin if err != nil { return err } - e := s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) switch p := pkt.(type) { case *protocol.ReadyForQuery: // all ReadyForQuery packets end a simple query, regardless of type if p.Fields.Status != 'I' { // send to client and wait for next query - if err == nil { - err = client.Send(pkt) - } + err = client.Send(pkt) if err == nil { select { @@ -470,26 +471,17 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin end.Fields.Query = "END;" _ = s.writePacket(end) } + } else { + finish = true } - return p.Fields.Status == 'I', p.Fields.Status == 'I' case *protocol.CopyInResponse: - err = client.Send(pkt) - if err != nil { - return false, false - } + _ = client.Send(pkt) err = s.CopyIn(ctx, client) - if err != nil { - return false, false - } - return false, false default: - return err == nil, false + forward = true } + return }) - if e != nil { - return e - } - return err } func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { @@ -525,13 +517,14 @@ func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) return err } // read responses - return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool) { + return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { switch pkt.(type) { case *protocol.ReadyForQuery: // status 'I' should only be encountered here - return true, true + finish = true default: - return true, false + forward = true } + return }) } diff --git a/lib/gat/gatling/gatling.go b/lib/gat/gatling/gatling.go index a2e2795396cc60cb8acb39a72bf59f37a47f9824..d4b62f4df95a65d9030c577119ff125772a7f817 100644 --- a/lib/gat/gatling/gatling.go +++ b/lib/gat/gatling/gatling.go @@ -124,17 +124,7 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { err := cl.Accept(ctx) if err != nil { log.Println("err in connection:", err.Error()) - switch e := err.(type) { - case *pg_error.Error: - _ = cl.Send(e.Packet()) - default: - pgErr := &pg_error.Error{ - Severity: pg_error.Err, - Code: pg_error.InternalError, - Message: e.Error(), - } - _ = cl.Send(pgErr.Packet()) - } + _ = cl.Send(pg_error.IntoPacket(err)) } _ = c.Close() return nil diff --git a/lib/gat/protocol/pg_error/error.go b/lib/gat/protocol/pg_error/error.go index 598d4a3d701ef37ecd7746bf0795aed9e9cc56f8..e3a7647b3a609ca20239c37f106b9a9d11f93633 100644 --- a/lib/gat/protocol/pg_error/error.go +++ b/lib/gat/protocol/pg_error/error.go @@ -451,3 +451,17 @@ func (E *Error) Packet() *protocol.ErrorResponse { func (E *Error) Error() string { return fmt.Sprintf("%s: %s", E.Severity, E.Message) } + +func IntoPacket(err error) *protocol.ErrorResponse { + switch e := err.(type) { + case *Error: + return e.Packet() + default: + er := Error{ + Severity: Err, + Code: InternalError, + Message: e.Error(), + } + return er.Packet() + } +}