diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 288bca2ee02760c383e58f3e032bc9cea56181ac..20ac2ca5b4ba95db943bb5fe13dca464001e8d89 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -450,7 +450,7 @@ func (c *Client) recvLoop() { } break } - //log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) + log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) c.recv <- recv } } diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 019f8a1378f60bdbe24401bc41f0508b4569d887..667991cd46ea46c14cfe4b1bbf7e975a676fa5ca 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "errors" "fmt" "net" "reflect" @@ -397,15 +398,17 @@ func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { } } - // test if prepared statement is the same - if prev, ok := s.boundPreparedStatments[name]; ok { - if reflect.DeepEqual(prev, stmt) { - // we don't need to bind, we're good - return nil - } + if name != "" { + // test if prepared statement is the same + if prev, ok := s.boundPreparedStatments[name]; ok { + if reflect.DeepEqual(prev, stmt) { + // we don't need to bind, we're good + return nil + } - // there is a statement bound that needs to be unbound - s.destructPreparedStatement(name) + // there is a statement bound that needs to be unbound + s.destructPreparedStatement(name) + } } s.boundPreparedStatments[name] = stmt @@ -516,6 +519,7 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { } func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { + log.Printf("execute `%s`", e.Fields.Name) err := s.ensurePortal(client, e.Fields.Name) if err != nil { return err @@ -536,10 +540,19 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) - switch pkt.(type) { + switch p := pkt.(type) { case *protocol.BindComplete, *protocol.ParseComplete: case *protocol.ReadyForQuery: - finish = true + if p.Fields.Status != 'I' { + err = errors.New("transactions are not allowed in statements") + + end := new(protocol.Query) + end.Fields.Query = "END" + _ = s.writePacket(end) + _ = s.flush() + } else { + finish = true + } default: forward = true } diff --git a/test/docker-compose.yml b/test/docker-compose.yml index d360903f567c79a0f654f3910ca68a6cbaa6cbb8..3b8a6e2373b75b90dcf01ce1c2a911ed667e0939 100644 --- a/test/docker-compose.yml +++ b/test/docker-compose.yml @@ -5,7 +5,7 @@ services: image: postgres restart: always environment: - POSTGRES_PASSWORD: exmaple + POSTGRES_PASSWORD: example ports: - 5432:5432 adminer: @@ -17,6 +17,10 @@ services: build: ../ restart: always environment: - PGGAT_DB_PASS: example + PSQL_DB_USER_RW: postgres + PSQL_DB_PASS_RW: example + PSQL_DB_USER_RO: postgres + PSQL_DB_PASS_RO: example + PSQL_PRI_DB_HOST: db ports: - 6432:6432