From ed08ed7b2b28de3112226df5807331642c8eab9e Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 12 Sep 2023 18:10:04 -0500
Subject: [PATCH] just need result comparison

---
 lib/gsql/client.go                      | 30 ++++++++++++-----
 lib/middleware/middlewares/ps/client.go |  2 +-
 lib/util/maps/clone.go                  | 14 ++++++++
 test/runner.go                          | 44 ++++++++++++++++++++++++-
 test/test.go                            |  1 -
 test/tests/simple_query.go              |  1 -
 6 files changed, 79 insertions(+), 13 deletions(-)
 create mode 100644 lib/util/maps/clone.go

diff --git a/lib/gsql/client.go b/lib/gsql/client.go
index ff9a15e6..436acd50 100644
--- a/lib/gsql/client.go
+++ b/lib/gsql/client.go
@@ -41,6 +41,22 @@ func (T *Client) Do(result ResultWriter, packets ...fed.Packet) {
 	}
 }
 
+func (T *Client) queueNext() bool {
+	b, ok := T.queue.PopFront()
+	if ok {
+		for _, packet := range b.packets {
+			T.read.PushBack(packet)
+		}
+		T.write = b.result
+		if T.writeC != nil {
+			T.writeC.Broadcast()
+		}
+		return true
+	}
+
+	return false
+}
+
 func (T *Client) ReadPacket(typed bool) (fed.Packet, error) {
 	T.mu.Lock()
 	defer T.mu.Unlock()
@@ -54,15 +70,7 @@ func (T *Client) ReadPacket(typed bool) (fed.Packet, error) {
 		}
 
 		// try to add next in queue
-		b, ok := T.queue.PopFront()
-		if ok {
-			for _, packet := range b.packets {
-				T.read.PushBack(packet)
-			}
-			T.write = b.result
-			if T.writeC != nil {
-				T.writeC.Broadcast()
-			}
+		if T.queueNext() {
 			continue
 		}
 
@@ -88,6 +96,10 @@ func (T *Client) WritePacket(packet fed.Packet) error {
 	defer T.mu.Unlock()
 
 	for T.write == nil {
+		if T.read.Length() == 0 && T.queueNext() {
+			continue
+		}
+
 		if T.closed {
 			return io.EOF
 		}
diff --git a/lib/middleware/middlewares/ps/client.go b/lib/middleware/middlewares/ps/client.go
index bdc92e77..e3970c8f 100644
--- a/lib/middleware/middlewares/ps/client.go
+++ b/lib/middleware/middlewares/ps/client.go
@@ -30,7 +30,7 @@ func (T *Client) Write(ctx middleware.Context, packet fed.Packet) error {
 			return errors.New("bad packet format i")
 		}
 		ikey := strutil.MakeCIString(ps.Key)
-		if T.parameters[ikey] == ps.Value {
+		if T.synced && T.parameters[ikey] == ps.Value {
 			// already set
 			ctx.Cancel()
 			break
diff --git a/lib/util/maps/clone.go b/lib/util/maps/clone.go
new file mode 100644
index 00000000..4b6c6162
--- /dev/null
+++ b/lib/util/maps/clone.go
@@ -0,0 +1,14 @@
+package maps
+
+func Clone[K comparable, V any](m map[K]V) map[K]V {
+	if m == nil {
+		return nil
+	}
+
+	m2 := make(map[K]V, len(m))
+	for k, v := range m {
+		m2[k] = v
+	}
+
+	return m2
+}
diff --git a/test/runner.go b/test/runner.go
index 74e234f9..f61cb293 100644
--- a/test/runner.go
+++ b/test/runner.go
@@ -6,6 +6,7 @@ import (
 
 	"tuxpa.in/a/zlog/log"
 
+	"pggat/lib/bouncer/bouncers/v2"
 	"pggat/lib/fed"
 	packets "pggat/lib/fed/packets/v3.0"
 	"pggat/lib/gat/pool"
@@ -37,7 +38,10 @@ func (T *Runner) setup() error {
 	}
 
 	for name, options := range T.config.Modes {
-		p := pool.NewPool(options)
+		opts := options
+		// allowing ps sync would mess up testing
+		opts.ParameterStatusSync = pool.ParameterStatusSyncNone
+		p := pool.NewPool(opts)
 		p.AddRecipe("server", recipe.NewRecipe(
 			recipe.Options{
 				Dialer: T.config.Peer,
@@ -57,7 +61,45 @@ func (logWriter) WritePacket(pkt fed.Packet) error {
 }
 
 func (T *Runner) run(pkts ...fed.Packet) error {
+	// expected
+	{
+		log.Print("expected packets")
+
+		var client gsql.Client
+		client.Do(logWriter{}, pkts...)
+		if err := client.Close(); err != nil {
+			return err
+		}
+
+		server, _, err := T.config.Peer.Dial()
+		if err != nil {
+			return err
+		}
+
+		for {
+			p, err := client.ReadPacket(true)
+			if err != nil {
+				if errors.Is(err, io.EOF) {
+					break
+				}
+				return err
+			}
+
+			clientErr, serverErr := bouncers.Bounce(&client, server, p)
+			if clientErr != nil {
+				return clientErr
+			}
+			if serverErr != nil {
+				return serverErr
+			}
+		}
+	}
+
+	// actual
 	for name, p := range T.pools {
+		log.Print()
+		log.Print("pool ", name)
+
 		var client gsql.Client
 		client.Do(logWriter{}, pkts...)
 		if err := client.Close(); err != nil {
diff --git a/test/test.go b/test/test.go
index 8ef87487..49b095d0 100644
--- a/test/test.go
+++ b/test/test.go
@@ -3,6 +3,5 @@ package test
 import "pggat/test/inst"
 
 type Test struct {
-	Parallel     bool
 	Instructions []inst.Instruction
 }
diff --git a/test/tests/simple_query.go b/test/tests/simple_query.go
index d1697c19..638ec681 100644
--- a/test/tests/simple_query.go
+++ b/test/tests/simple_query.go
@@ -6,7 +6,6 @@ import (
 )
 
 var SimpleQuery = test.Test{
-	Parallel: true,
 	Instructions: []inst.Instruction{
 		inst.SimpleQuery("select 1;"),
 	},
-- 
GitLab