From 3409833ae03074d8e455c76c7fbd517140c3755c Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 12 Sep 2023 18:18:36 -0500
Subject: [PATCH] gaming

---
 test/capturer.go    | 37 +++++++++++++++++++++++++++++++++++++
 test/runner.go      | 23 ++++++++---------------
 test/tester_test.go |  4 +++-
 3 files changed, 48 insertions(+), 16 deletions(-)
 create mode 100644 test/capturer.go

diff --git a/test/capturer.go b/test/capturer.go
new file mode 100644
index 00000000..5b6dfeb8
--- /dev/null
+++ b/test/capturer.go
@@ -0,0 +1,37 @@
+package test
+
+import (
+	"bytes"
+	"fmt"
+
+	"pggat/lib/fed"
+	"pggat/lib/gsql"
+)
+
+type Capturer struct {
+	Packets []fed.Packet
+}
+
+func (T *Capturer) WritePacket(packet fed.Packet) error {
+	T.Packets = append(T.Packets, packet)
+	return nil
+}
+
+func (T *Capturer) Check(other *Capturer) error {
+	if len(T.Packets) != len(other.Packets) {
+		return fmt.Errorf("not enough packets! got %d but expected %d", len(other.Packets), len(T.Packets))
+	}
+
+	for i := range T.Packets {
+		expected := T.Packets[i]
+		actual := other.Packets[i]
+
+		if !bytes.Equal(expected.Bytes(), actual.Bytes()) {
+			return fmt.Errorf("mismatched packet! expected %v but got %v", expected.Bytes(), actual.Bytes())
+		}
+	}
+
+	return nil
+}
+
+var _ gsql.ResultWriter = (*Capturer)(nil)
diff --git a/test/runner.go b/test/runner.go
index f61cb293..4cd235c5 100644
--- a/test/runner.go
+++ b/test/runner.go
@@ -4,8 +4,6 @@ import (
 	"errors"
 	"io"
 
-	"tuxpa.in/a/zlog/log"
-
 	"pggat/lib/bouncer/bouncers/v2"
 	"pggat/lib/fed"
 	packets "pggat/lib/fed/packets/v3.0"
@@ -53,20 +51,13 @@ func (T *Runner) setup() error {
 	return nil
 }
 
-type logWriter struct{}
-
-func (logWriter) WritePacket(pkt fed.Packet) error {
-	log.Print("got packet ", pkt)
-	return nil
-}
-
 func (T *Runner) run(pkts ...fed.Packet) error {
 	// expected
-	{
-		log.Print("expected packets")
+	var expected Capturer
 
+	{
 		var client gsql.Client
-		client.Do(logWriter{}, pkts...)
+		client.Do(&expected, pkts...)
 		if err := client.Close(); err != nil {
 			return err
 		}
@@ -97,11 +88,10 @@ func (T *Runner) run(pkts ...fed.Packet) error {
 
 	// actual
 	for name, p := range T.pools {
-		log.Print()
-		log.Print("pool ", name)
+		var result Capturer
 
 		var client gsql.Client
-		client.Do(logWriter{}, pkts...)
+		client.Do(&result, pkts...)
 		if err := client.Close(); err != nil {
 			return err
 		}
@@ -110,6 +100,9 @@ func (T *Runner) run(pkts ...fed.Packet) error {
 			return err
 		}
 
+		if err := expected.Check(&result); err != nil {
+			return err
+		}
 		_ = name
 	}
 
diff --git a/test/tester_test.go b/test/tester_test.go
index 9f74ebbf..809190fc 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -33,7 +33,9 @@ func TestTester(t *testing.T) {
 			},
 		},
 	})
-	if err := tester.Run(tests.SimpleQuery); err != nil {
+	if err := tester.Run(
+		tests.SimpleQuery,
+	); err != nil {
 		t.Error(err)
 	}
 }
-- 
GitLab