From 6763eaee398888a69a7b3efb72172f420fa3fd08 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 29 Aug 2023 21:00:37 -0500
Subject: [PATCH] fix eqp

---
 lib/gat/pool/pool.go                               |  7 +++++++
 lib/middleware/middlewares/eqp/client.go           | 14 +++++++-------
 lib/middleware/middlewares/eqp/portal.go           |  1 +
 .../middlewares/eqp/preparedStatement.go           |  1 +
 lib/middleware/middlewares/eqp/server.go           |  2 +-
 lib/middleware/middlewares/ps/client.go            |  2 +-
 lib/middleware/middlewares/ps/server.go            |  2 +-
 7 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index 1eafe352..f0c2ec3d 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -145,6 +145,13 @@ func (T *Pool) _scaleUpRecipe(name string) {
 		middlewares = append(middlewares, eqpServer)
 	}
 
+	if len(middlewares) > 0 {
+		server = interceptor.NewInterceptor(
+			server,
+			middlewares...,
+		)
+	}
+
 	T.servers[serverID] = &poolServer{
 		conn:   server,
 		accept: params,
diff --git a/lib/middleware/middlewares/eqp/client.go b/lib/middleware/middlewares/eqp/client.go
index 9d71b01f..3cfc9dd4 100644
--- a/lib/middleware/middlewares/eqp/client.go
+++ b/lib/middleware/middlewares/eqp/client.go
@@ -44,7 +44,7 @@ func (T *Client) Write(_ middleware.Context, packet zap.Packet) error {
 	case packets.TypeReadyForQuery:
 		var readyForQuery packets.ReadyForQuery
 		if !readyForQuery.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format a")
 		}
 		if readyForQuery == 'I' {
 			// clobber all named portals
@@ -70,7 +70,7 @@ func (T *Client) Read(ctx middleware.Context, packet zap.Packet) error {
 
 		destination, preparedStatement, ok := ReadParse(packet)
 		if !ok {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format b")
 		}
 
 		T.preparedStatements[destination] = preparedStatement
@@ -86,7 +86,7 @@ func (T *Client) Read(ctx middleware.Context, packet zap.Packet) error {
 
 		destination, portal, ok := ReadBind(packet)
 		if !ok {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format c")
 		}
 
 		T.portals[destination] = portal
@@ -102,7 +102,7 @@ func (T *Client) Read(ctx middleware.Context, packet zap.Packet) error {
 
 		var p packets.Close
 		if !p.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format d")
 		}
 		switch p.Which {
 		case 'S':
@@ -110,7 +110,7 @@ func (T *Client) Read(ctx middleware.Context, packet zap.Packet) error {
 		case 'P':
 			T.deletePortal(p.Target)
 		default:
-			return errors.New("bad packet format")
+			return errors.New("bad packet format e")
 		}
 
 		// send close complete
@@ -123,7 +123,7 @@ func (T *Client) Read(ctx middleware.Context, packet zap.Packet) error {
 		// ensure target exists
 		var describe packets.Describe
 		if !describe.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format f")
 		}
 		switch describe.Which {
 		case 'S', 'P':
@@ -134,7 +134,7 @@ func (T *Client) Read(ctx middleware.Context, packet zap.Packet) error {
 	case packets.TypeExecute:
 		var execute packets.Execute
 		if !execute.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format g")
 		}
 	}
 	return nil
diff --git a/lib/middleware/middlewares/eqp/portal.go b/lib/middleware/middlewares/eqp/portal.go
index 239423ec..43559895 100644
--- a/lib/middleware/middlewares/eqp/portal.go
+++ b/lib/middleware/middlewares/eqp/portal.go
@@ -22,5 +22,6 @@ func ReadBind(in zap.Packet) (destination string, portal Portal, ok bool) {
 
 	portal.packet = in
 	portal.hash = maphash.Bytes(seed, portal.packet.Payload())
+	ok = true
 	return
 }
diff --git a/lib/middleware/middlewares/eqp/preparedStatement.go b/lib/middleware/middlewares/eqp/preparedStatement.go
index bc64892f..e2498f47 100644
--- a/lib/middleware/middlewares/eqp/preparedStatement.go
+++ b/lib/middleware/middlewares/eqp/preparedStatement.go
@@ -21,5 +21,6 @@ func ReadParse(packet zap.Packet) (destination string, preparedStatement Prepare
 
 	preparedStatement.packet = packet
 	preparedStatement.hash = maphash.Bytes(seed, preparedStatement.packet.Payload())
+	ok = true
 	return
 }
diff --git a/lib/middleware/middlewares/eqp/server.go b/lib/middleware/middlewares/eqp/server.go
index 7d62472e..7cd44612 100644
--- a/lib/middleware/middlewares/eqp/server.go
+++ b/lib/middleware/middlewares/eqp/server.go
@@ -279,7 +279,7 @@ func (T *Server) Read(ctx middleware.Context, packet zap.Packet) error {
 	case packets.TypeReadyForQuery:
 		var state packets.ReadyForQuery
 		if !state.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format h")
 		}
 		if state == 'I' {
 			// clobber all portals
diff --git a/lib/middleware/middlewares/ps/client.go b/lib/middleware/middlewares/ps/client.go
index 18d13d71..de997b74 100644
--- a/lib/middleware/middlewares/ps/client.go
+++ b/lib/middleware/middlewares/ps/client.go
@@ -27,7 +27,7 @@ func (T *Client) Write(ctx middleware.Context, packet zap.Packet) error {
 	case packets.TypeParameterStatus:
 		var ps packets.ParameterStatus
 		if !ps.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format i")
 		}
 		ikey := strutil.MakeCIString(ps.Key)
 		if T.parameters[ikey] == ps.Value {
diff --git a/lib/middleware/middlewares/ps/server.go b/lib/middleware/middlewares/ps/server.go
index e4f652e5..4bf258cc 100644
--- a/lib/middleware/middlewares/ps/server.go
+++ b/lib/middleware/middlewares/ps/server.go
@@ -26,7 +26,7 @@ func (T *Server) Read(_ middleware.Context, packet zap.Packet) error {
 	case packets.TypeParameterStatus:
 		var ps packets.ParameterStatus
 		if !ps.ReadFromPacket(packet) {
-			return errors.New("bad packet format")
+			return errors.New("bad packet format j")
 		}
 		ikey := strutil.MakeCIString(ps.Key)
 		if T.parameters == nil {
-- 
GitLab