From c00131dbd2dc0a38697a901960efc89362287adc Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Thu, 14 Sep 2023 16:19:36 -0500
Subject: [PATCH] fix eqp copy from stdin

---
 lib/bouncer/backends/v0/query.go       | 36 +++++++++++++++++++-------
 lib/middleware/middlewares/eqp/sync.go |  3 ++-
 test/tester_test.go                    |  1 +
 test/tests/copy_in.go                  |  1 +
 4 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go
index 89df7fbb..44e57a79 100644
--- a/lib/bouncer/backends/v0/query.go
+++ b/lib/bouncer/backends/v0/query.go
@@ -142,15 +142,15 @@ func FunctionCall(ctx *Context, server fed.ReadWriter, packet fed.Packet) error
 	}
 }
 
-func Sync(ctx *Context, server fed.ReadWriter) error {
+func Sync(ctx *Context, server fed.ReadWriter) (bool, error) {
 	if err := server.WritePacket(fed.NewPacket(packets.TypeSync)); err != nil {
-		return err
+		return false, err
 	}
 
 	for {
 		packet, err := server.ReadPacket(true)
 		if err != nil {
-			return err
+			return false, err
 		}
 
 		switch packet.Type() {
@@ -173,22 +173,24 @@ func Sync(ctx *Context, server fed.ReadWriter) error {
 			ctx.PeerWrite(packet)
 		case packets.TypeCopyInResponse:
 			if err = CopyIn(ctx, server, packet); err != nil {
-				return err
+				return false, err
 			}
+			// why
+			return false, nil
 		case packets.TypeCopyOutResponse:
 			if err = CopyOut(ctx, server, packet); err != nil {
-				return err
+				return false, err
 			}
 		case packets.TypeReadyForQuery:
 			var txState packets.ReadyForQuery
 			if !txState.ReadFromPacket(packet) {
-				return ErrBadFormat
+				return false, ErrBadFormat
 			}
 			ctx.TxState = byte(txState)
 			ctx.PeerWrite(packet)
-			return nil
+			return true, nil
 		default:
-			return ErrUnexpectedPacket
+			return false, ErrUnexpectedPacket
 		}
 	}
 }
@@ -201,12 +203,26 @@ func EQP(ctx *Context, server fed.ReadWriter, packet fed.Packet) error {
 	for {
 		packet = ctx.PeerRead()
 		if packet == nil {
-			return Sync(ctx, server)
+			for {
+				ok, err := Sync(ctx, server)
+				if err != nil {
+					return err
+				}
+				if ok {
+					return nil
+				}
+			}
 		}
 
 		switch packet.Type() {
 		case packets.TypeSync:
-			return Sync(ctx, server)
+			ok, err := Sync(ctx, server)
+			if err != nil {
+				return err
+			}
+			if ok {
+				return nil
+			}
 		case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush:
 			if err := server.WritePacket(packet); err != nil {
 				return err
diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go
index 848ec954..57b81839 100644
--- a/lib/middleware/middlewares/eqp/sync.go
+++ b/lib/middleware/middlewares/eqp/sync.go
@@ -64,5 +64,6 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error {
 		}
 	}
 
-	return backends.Sync(new(backends.Context), server)
+	_, err := backends.Sync(new(backends.Context), server)
+	return err
 }
diff --git a/test/tester_test.go b/test/tester_test.go
index a1001767..0365a5bc 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -5,6 +5,7 @@ import (
 	"encoding/hex"
 	"fmt"
 	"net"
+	_ "net/http/pprof"
 	"strconv"
 	"testing"
 
diff --git a/test/tests/copy_in.go b/test/tests/copy_in.go
index 18f85261..1bbe5bec 100644
--- a/test/tests/copy_in.go
+++ b/test/tests/copy_in.go
@@ -34,6 +34,7 @@ var CopyIn1 = test.Test{
 		inst.CopyData{49, 50, 51, 9, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 10},
 		inst.CopyData{45, 51, 50, 52, 9, 103, 97, 114, 101, 116, 32, 119, 97, 115, 32, 104, 101, 114, 101, 10},
 		inst.CopyDone{},
+		inst.Sync{},
 		inst.SimpleQuery("DROP TABLE test"),
 	},
 }
-- 
GitLab