From 46655c8fa8c0977736961189ed59a457c7de1407 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Thu, 12 Oct 2023 15:46:20 -0500
Subject: [PATCH] ready

---
 test/capturer.go |  37 -------------
 test/runner.go   | 135 +++++++++++++++++++++++++----------------------
 2 files changed, 71 insertions(+), 101 deletions(-)
 delete mode 100644 test/capturer.go

diff --git a/test/capturer.go b/test/capturer.go
deleted file mode 100644
index 3bdc1e3a..00000000
--- a/test/capturer.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package test
-
-import (
-	"bytes"
-	"fmt"
-
-	"gfx.cafe/gfx/pggat/lib/fed"
-	"gfx.cafe/gfx/pggat/lib/gsql"
-)
-
-type Capturer struct {
-	Packets []fed.Packet
-}
-
-func (T *Capturer) WritePacket(packet fed.Packet) error {
-	T.Packets = append(T.Packets, bytes.Clone(packet))
-	return nil
-}
-
-func (T *Capturer) Check(other *Capturer) error {
-	if len(T.Packets) != len(other.Packets) {
-		return fmt.Errorf("wrong number of packets! got %d but expected %d", len(other.Packets), len(T.Packets))
-	}
-
-	for i := range T.Packets {
-		expected := T.Packets[i]
-		actual := other.Packets[i]
-
-		if !bytes.Equal(expected.Bytes(), actual.Bytes()) {
-			return fmt.Errorf("mismatched packet! expected %v but got %v", expected.Bytes(), actual.Bytes())
-		}
-	}
-
-	return nil
-}
-
-var _ gsql.ResultWriter = (*Capturer)(nil)
diff --git a/test/runner.go b/test/runner.go
index d23e94c6..c1f24ccc 100644
--- a/test/runner.go
+++ b/test/runner.go
@@ -1,12 +1,14 @@
 package test
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"io"
 
 	"gfx.cafe/gfx/pggat/lib/bouncer/bouncers/v2"
 	"gfx.cafe/gfx/pggat/lib/fed"
+	"gfx.cafe/gfx/pggat/lib/fed/middlewares/unterminate"
 	packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
 	"gfx.cafe/gfx/pggat/lib/gat/pool/recipe"
 	"gfx.cafe/gfx/pggat/lib/gsql"
@@ -26,67 +28,91 @@ func MakeRunner(config Config, test Test) Runner {
 	}
 }
 
-func (T *Runner) prepare(client *gsql.Client, until int) []Capturer {
-	results := make([]Capturer, until)
-
+func (T *Runner) prepare(client *fed.Conn, until int) error {
 	for i := 0; i < until; i++ {
 		x := T.test.Instructions[i]
 		switch v := x.(type) {
 		case inst.SimpleQuery:
 			q := packets.Query(v)
-			client.Do(&results[i], q.IntoPacket(nil))
+			if err := client.WritePacket(&q); err != nil {
+				return err
+			}
 		case inst.Sync:
-			client.Do(&results[i], fed.NewPacket(packets.TypeSync))
+			if err := client.WritePacket(&packets.Sync{}); err != nil {
+				return err
+			}
 		case inst.Parse:
 			p := packets.Parse{
 				Destination: v.Destination,
 				Query:       v.Query,
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
+			if err := client.WritePacket(&p); err != nil {
+				return err
+			}
 		case inst.Bind:
 			p := packets.Bind{
 				Destination: v.Destination,
 				Source:      v.Source,
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
+			if err := client.WritePacket(&p); err != nil {
+				return err
+			}
 		case inst.DescribePortal:
 			p := packets.Describe{
-				Which:  'P',
-				Target: string(v),
+				Which: 'P',
+				Name:  string(v),
+			}
+			if err := client.WritePacket(&p); err != nil {
+				return err
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
 		case inst.DescribePreparedStatement:
 			p := packets.Describe{
-				Which:  'S',
-				Target: string(v),
+				Which: 'S',
+				Name:  string(v),
+			}
+			if err := client.WritePacket(&p); err != nil {
+				return err
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
 		case inst.Execute:
 			p := packets.Execute{
 				Target: string(v),
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
+			if err := client.WritePacket(&p); err != nil {
+				return err
+			}
 		case inst.ClosePortal:
 			p := packets.Close{
-				Which:  'P',
-				Target: string(v),
+				Which: 'P',
+				Name:  string(v),
+			}
+			if err := client.WritePacket(&p); err != nil {
+				return err
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
 		case inst.ClosePreparedStatement:
 			p := packets.Close{
-				Which:  'S',
-				Target: string(v),
+				Which: 'S',
+				Name:  string(v),
+			}
+			if err := client.WritePacket(&p); err != nil {
+				return err
 			}
-			client.Do(&results[i], p.IntoPacket(nil))
 		case inst.CopyData:
 			p := packets.CopyData(v)
-			client.Do(&results[i], p.IntoPacket(nil))
+			if err := client.WritePacket(&p); err != nil {
+				return err
+			}
 		case inst.CopyDone:
-			client.Do(&results[i], fed.NewPacket(packets.TypeCopyDone))
+			if err := client.WritePacket(&packets.CopyDone{}); err != nil {
+				return err
+			}
 		}
 	}
 
-	return results
+	if err := client.WritePacket(&packets.Terminate{}); err != nil {
+		return err
+	}
+
+	return client.Flush()
 }
 
 func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error {
@@ -98,9 +124,11 @@ func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error {
 		_ = server.Close()
 	}()
 
+	client.Middleware = append(client.Middleware, unterminate.Unterminate)
+
 	for {
 		var p fed.Packet
-		p, err = client.ReadPacket(true, p)
+		p, err = client.ReadPacket(true)
 		if err != nil {
 			if errors.Is(err, io.EOF) {
 				break
@@ -108,7 +136,7 @@ func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error {
 			return err
 		}
 
-		_, clientErr, serverErr := bouncers.Bounce(client, server, p)
+		clientErr, serverErr := bouncers.Bounce(client, server, p)
 		if clientErr != nil {
 			return clientErr
 		}
@@ -120,29 +148,31 @@ func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error {
 	return nil
 }
 
-func (T *Runner) runModeOnce(dialer recipe.Dialer) ([]Capturer, error) {
-	var client gsql.Client
-	results := T.prepare(&client, len(T.test.Instructions))
-	if err := client.Close(); err != nil {
+func (T *Runner) runModeOnce(dialer recipe.Dialer) ([]byte, error) {
+	inward, outward := gsql.NewPair()
+	if err := T.prepare(inward, len(T.test.Instructions)); err != nil {
+		return nil, err
+	}
+
+	if err := T.runModeL1(dialer, outward); err != nil {
 		return nil, err
 	}
 
-	if err := T.runModeL1(dialer, fed.NewConn(&client)); err != nil {
+	if err := inward.Close(); err != nil {
 		return nil, err
 	}
 
-	return results, nil
+	return io.ReadAll(inward.NetConn)
 }
 
 func (T *Runner) runModeFail(dialer recipe.Dialer) error {
 	for i := 1; i < len(T.test.Instructions)+1; i++ {
-		var client gsql.Client
-		T.prepare(&client, i)
-		if err := client.Close(); err != nil {
+		inward, outward := gsql.NewPair()
+		if err := T.prepare(inward, i); err != nil {
 			return err
 		}
 
-		if err := T.runModeL1(dialer, fed.NewConn(&client)); err != nil && !errors.Is(err, io.EOF) {
+		if err := T.runModeL1(dialer, outward); err != nil && !errors.Is(err, io.EOF) {
 			return err
 		}
 	}
@@ -150,7 +180,7 @@ func (T *Runner) runModeFail(dialer recipe.Dialer) error {
 	return nil
 }
 
-func (T *Runner) runMode(dialer recipe.Dialer) ([]Capturer, error) {
+func (T *Runner) runMode(dialer recipe.Dialer) ([]byte, error) {
 	instances := T.config.Stress
 	if instances < 1 || T.test.SideEffects {
 		return T.runModeOnce(dialer)
@@ -175,14 +205,8 @@ func (T *Runner) runMode(dialer recipe.Dialer) ([]Capturer, error) {
 			if err != nil {
 				return err
 			}
-			if len(expected) != len(actual) {
-				return fmt.Errorf("wrong number of results! expected %d but got %d", len(expected), len(actual))
-			}
-			for i, exp := range expected {
-				act := actual[i]
-				if err = exp.Check(&act); err != nil {
-					return err
-				}
+			if !bytes.Equal(expected, actual) {
+				return fmt.Errorf("mismatched results: expected %v but got %v", expected, actual)
 			}
 			return nil
 		})
@@ -198,7 +222,7 @@ func (T *Runner) runMode(dialer recipe.Dialer) ([]Capturer, error) {
 func (T *Runner) Run() error {
 	var errs []error
 
-	var expected []Capturer
+	var expected []byte
 
 	// modes
 	for name, mode := range T.config.Modes {
@@ -216,30 +240,13 @@ func (T *Runner) Run() error {
 			continue
 		}
 
-		if len(expected) != len(actual) {
+		if !bytes.Equal(expected, actual) {
 			errs = append(errs, ErrorIn{
 				Name: name,
-				Err:  fmt.Errorf("wrong number of results! expected %d but got %d", len(expected), len(actual)),
+				Err:  fmt.Errorf("mismatched results: expected %v but got %v", expected, actual),
 			})
 			continue
 		}
-
-		var modeErrs []error
-
-		for i, exp := range expected {
-			act := actual[i]
-
-			if err = exp.Check(&act); err != nil {
-				modeErrs = append(modeErrs, fmt.Errorf("instruction %d: %s", i+1, err.Error()))
-			}
-		}
-
-		if len(modeErrs) > 0 {
-			errs = append(errs, ErrorIn{
-				Name: name,
-				Err:  Errors(modeErrs),
-			})
-		}
 	}
 
 	if len(errs) > 0 {
-- 
GitLab