From d586233a4bb6bd05df420e482fcf8949ea4584f3 Mon Sep 17 00:00:00 2001 From: Garet Halliday <ghalliday@gfxlabs.io> Date: Thu, 29 Sep 2022 17:36:45 -0500 Subject: [PATCH] don't allow clients to do bad stuff --- lib/gat/gatling/client/client.go | 2 +- lib/gat/gatling/server/server.go | 33 ++++++++++++++++++++++---------- test/docker-compose.yml | 8 ++++++-- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 288bca2e..20ac2ca5 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -450,7 +450,7 @@ func (c *Client) recvLoop() { } break } - //log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) + log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv) c.recv <- recv } } diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 019f8a13..667991cd 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "errors" "fmt" "net" "reflect" @@ -397,15 +398,17 @@ func (s *Server) ensurePreparedStatement(client gat.Client, name string) error { } } - // test if prepared statement is the same - if prev, ok := s.boundPreparedStatments[name]; ok { - if reflect.DeepEqual(prev, stmt) { - // we don't need to bind, we're good - return nil - } + if name != "" { + // test if prepared statement is the same + if prev, ok := s.boundPreparedStatments[name]; ok { + if reflect.DeepEqual(prev, stmt) { + // we don't need to bind, we're good + return nil + } - // there is a statement bound that needs to be unbound - s.destructPreparedStatement(name) + // there is a statement bound that needs to be unbound + s.destructPreparedStatement(name) + } } s.boundPreparedStatments[name] = stmt @@ -516,6 +519,7 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { } func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { + log.Printf("execute `%s`", e.Fields.Name) err := s.ensurePortal(client, e.Fields.Name) if err != nil { return err @@ -536,10 +540,19 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) - switch pkt.(type) { + switch p := pkt.(type) { case *protocol.BindComplete, *protocol.ParseComplete: case *protocol.ReadyForQuery: - finish = true + if p.Fields.Status != 'I' { + err = errors.New("transactions are not allowed in statements") + + end := new(protocol.Query) + end.Fields.Query = "END" + _ = s.writePacket(end) + _ = s.flush() + } else { + finish = true + } default: forward = true } diff --git a/test/docker-compose.yml b/test/docker-compose.yml index d360903f..3b8a6e23 100644 --- a/test/docker-compose.yml +++ b/test/docker-compose.yml @@ -5,7 +5,7 @@ services: image: postgres restart: always environment: - POSTGRES_PASSWORD: exmaple + POSTGRES_PASSWORD: example ports: - 5432:5432 adminer: @@ -17,6 +17,10 @@ services: build: ../ restart: always environment: - PGGAT_DB_PASS: example + PSQL_DB_USER_RW: postgres + PSQL_DB_PASS_RW: example + PSQL_DB_USER_RO: postgres + PSQL_DB_PASS_RO: example + PSQL_PRI_DB_HOST: db ports: - 6432:6432 -- GitLab