From d586233a4bb6bd05df420e482fcf8949ea4584f3 Mon Sep 17 00:00:00 2001
From: Garet Halliday <ghalliday@gfxlabs.io>
Date: Thu, 29 Sep 2022 17:36:45 -0500
Subject: [PATCH] don't allow clients to do bad stuff

---
 lib/gat/gatling/client/client.go |  2 +-
 lib/gat/gatling/server/server.go | 33 ++++++++++++++++++++++----------
 test/docker-compose.yml          |  8 ++++++--
 3 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go
index 288bca2e..20ac2ca5 100644
--- a/lib/gat/gatling/client/client.go
+++ b/lib/gat/gatling/client/client.go
@@ -450,7 +450,7 @@ func (c *Client) recvLoop() {
 			}
 			break
 		}
-		//log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv)
+		log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv)
 		c.recv <- recv
 	}
 }
diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go
index 019f8a13..667991cd 100644
--- a/lib/gat/gatling/server/server.go
+++ b/lib/gat/gatling/server/server.go
@@ -2,6 +2,7 @@ package server
 
 import (
 	"bufio"
+	"errors"
 	"fmt"
 	"net"
 	"reflect"
@@ -397,15 +398,17 @@ func (s *Server) ensurePreparedStatement(client gat.Client, name string) error {
 		}
 	}
 
-	// test if prepared statement is the same
-	if prev, ok := s.boundPreparedStatments[name]; ok {
-		if reflect.DeepEqual(prev, stmt) {
-			// we don't need to bind, we're good
-			return nil
-		}
+	if name != "" {
+		// test if prepared statement is the same
+		if prev, ok := s.boundPreparedStatments[name]; ok {
+			if reflect.DeepEqual(prev, stmt) {
+				// we don't need to bind, we're good
+				return nil
+			}
 
-		// there is a statement bound that needs to be unbound
-		s.destructPreparedStatement(name)
+			// there is a statement bound that needs to be unbound
+			s.destructPreparedStatement(name)
+		}
 	}
 
 	s.boundPreparedStatments[name] = stmt
@@ -516,6 +519,7 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error {
 }
 
 func (s *Server) Execute(client gat.Client, e *protocol.Execute) error {
+	log.Printf("execute `%s`", e.Fields.Name)
 	err := s.ensurePortal(client, e.Fields.Name)
 	if err != nil {
 		return err
@@ -536,10 +540,19 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error {
 
 	return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
 		//log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt)
-		switch pkt.(type) {
+		switch p := pkt.(type) {
 		case *protocol.BindComplete, *protocol.ParseComplete:
 		case *protocol.ReadyForQuery:
-			finish = true
+			if p.Fields.Status != 'I' {
+				err = errors.New("transactions are not allowed in statements")
+
+				end := new(protocol.Query)
+				end.Fields.Query = "END"
+				_ = s.writePacket(end)
+				_ = s.flush()
+			} else {
+				finish = true
+			}
 		default:
 			forward = true
 		}
diff --git a/test/docker-compose.yml b/test/docker-compose.yml
index d360903f..3b8a6e23 100644
--- a/test/docker-compose.yml
+++ b/test/docker-compose.yml
@@ -5,7 +5,7 @@ services:
     image: postgres
     restart: always
     environment:
-      POSTGRES_PASSWORD: exmaple
+      POSTGRES_PASSWORD: example
     ports:
       - 5432:5432
   adminer:
@@ -17,6 +17,10 @@ services:
     build: ../
     restart: always
     environment:
-      PGGAT_DB_PASS: example
+      PSQL_DB_USER_RW: postgres
+      PSQL_DB_PASS_RW: example
+      PSQL_DB_USER_RO: postgres
+      PSQL_DB_PASS_RO: example
+      PSQL_PRI_DB_HOST: db
     ports:
       - 6432:6432
-- 
GitLab