From 357d39036acf6ed6650a697ddc84df3b86b7d17a Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Thu, 14 Sep 2023 14:58:50 -0500
Subject: [PATCH] fail testing

---
 test/runner.go      | 69 +++++++++++++++++++++++++++++++++-----------
 test/tester_test.go | 70 ++++++++++++++++++++++++++-------------------
 2 files changed, 92 insertions(+), 47 deletions(-)

diff --git a/test/runner.go b/test/runner.go
index 5f985f1f..8ed80068 100644
--- a/test/runner.go
+++ b/test/runner.go
@@ -26,10 +26,11 @@ func MakeRunner(config Config, test Test) Runner {
 	}
 }
 
-func (T *Runner) prepare(client *gsql.Client) []Capturer {
-	results := make([]Capturer, len(T.test.Instructions))
+func (T *Runner) prepare(client *gsql.Client, until int) []Capturer {
+	results := make([]Capturer, until)
 
-	for i, x := range T.test.Instructions {
+	for i := 0; i < until; i++ {
+		x := T.test.Instructions[i]
 		switch v := x.(type) {
 		case inst.SimpleQuery:
 			q := packets.Query(v)
@@ -83,21 +84,15 @@ func (T *Runner) prepare(client *gsql.Client) []Capturer {
 	return results
 }
 
-func (T *Runner) runModeOnce(dialer dialer.Dialer) ([]Capturer, error) {
+func (T *Runner) runModeL1(dialer dialer.Dialer, client *gsql.Client) error {
 	server, _, err := dialer.Dial()
 	if err != nil {
-		return nil, err
+		return err
 	}
 	defer func() {
 		_ = server.Close()
 	}()
 
-	var client gsql.Client
-	results := T.prepare(&client)
-	if err = client.Close(); err != nil {
-		return nil, err
-	}
-
 	for {
 		var p fed.Packet
 		p, err = client.ReadPacket(true)
@@ -105,25 +100,55 @@ func (T *Runner) runModeOnce(dialer dialer.Dialer) ([]Capturer, error) {
 			if errors.Is(err, io.EOF) {
 				break
 			}
-			return nil, err
+			return err
 		}
 
-		clientErr, serverErr := bouncers.Bounce(&client, server, p)
+		clientErr, serverErr := bouncers.Bounce(client, server, p)
 		if clientErr != nil {
-			return nil, clientErr
+			return clientErr
 		}
 		if serverErr != nil {
-			return nil, serverErr
+			return serverErr
 		}
 	}
 
+	return nil
+}
+
+func (T *Runner) runModeOnce(dialer dialer.Dialer) ([]Capturer, error) {
+	var client gsql.Client
+	results := T.prepare(&client, len(T.test.Instructions))
+	if err := client.Close(); err != nil {
+		return nil, err
+	}
+
+	if err := T.runModeL1(dialer, &client); err != nil {
+		return nil, err
+	}
+
 	return results, nil
 }
 
+func (T *Runner) runModeFail(dialer dialer.Dialer) error {
+	for i := 1; i < len(T.test.Instructions)+1; i++ {
+		var client gsql.Client
+		T.prepare(&client, i)
+		if err := client.Close(); err != nil {
+			return err
+		}
+
+		if err := T.runModeL1(dialer, &client); err != nil && !errors.Is(err, io.EOF) {
+			return err
+		}
+	}
+
+	return nil
+}
+
 func (T *Runner) runMode(dialer dialer.Dialer) ([]Capturer, error) {
 	instances := T.config.Stress
 	if instances < 1 || T.test.SideEffects {
-		instances = 1
+		return T.runModeOnce(dialer)
 	}
 
 	expected, err := T.runModeOnce(dialer)
@@ -131,6 +156,12 @@ func (T *Runner) runMode(dialer dialer.Dialer) ([]Capturer, error) {
 		return nil, err
 	}
 
+	// fail testing
+	if err = T.runModeFail(dialer); err != nil {
+		return nil, err
+	}
+
+	// stress test
 	var b flip.Bank
 
 	for i := 0; i < instances-1; i++ {
@@ -152,7 +183,11 @@ func (T *Runner) runMode(dialer dialer.Dialer) ([]Capturer, error) {
 		})
 	}
 
-	return expected, b.Wait()
+	if err = b.Wait(); err != nil {
+		return nil, err
+	}
+
+	return expected, nil
 }
 
 func (T *Runner) Run() error {
diff --git a/test/tester_test.go b/test/tester_test.go
index 4b1fa6f2..6b32c603 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -8,6 +8,7 @@ import (
 	"strconv"
 	"testing"
 
+	"pggat/lib/auth"
 	"pggat/lib/auth/credentials"
 	"pggat/lib/bouncer/backends/v0"
 	"pggat/lib/bouncer/frontends/v0"
@@ -21,33 +22,8 @@ import (
 	"pggat/test/tests"
 )
 
-func TestTester(t *testing.T) {
-	control := dialer.Net{
-		Network: "tcp",
-		Address: "localhost:5432",
-		AcceptOptions: backends.AcceptOptions{
-			Credentials: credentials.Cleartext{
-				Username: "postgres",
-				Password: "password",
-			},
-			Database: "postgres",
-		},
-	}
-
-	// generate random password for testing
-	var raw [32]byte
-	_, err := rand.Read(raw[:])
-	if err != nil {
-		t.Error(err)
-		return
-	}
-	password := hex.EncodeToString(raw[:])
-	creds := credentials.Cleartext{
-		Username: "runner",
-		Password: password,
-	}
-
-	for i := 0; i < 70; i++ {
+func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, error) {
+	for i := 0; i < n; i++ {
 		var g gat.PoolsMap
 
 		var options = pool.Options{
@@ -68,15 +44,14 @@ func TestTester(t *testing.T) {
 
 		listener, err := gat.Listen("tcp", ":0", frontends.AcceptOptions{})
 		if err != nil {
-			t.Error(err)
-			return
+			return dialer.Net{}, err
 		}
 		port := listener.Listener.Addr().(*net.TCPAddr).Port
 
 		go func() {
 			err := gat.Serve(listener, &g)
 			if err != nil {
-				t.Error(err)
+				panic(err)
 			}
 		}()
 
@@ -90,6 +65,41 @@ func TestTester(t *testing.T) {
 		}
 	}
 
+	return control, nil
+}
+
+func TestTester(t *testing.T) {
+	control := dialer.Net{
+		Network: "tcp",
+		Address: "localhost:5432",
+		AcceptOptions: backends.AcceptOptions{
+			Credentials: credentials.Cleartext{
+				Username: "postgres",
+				Password: "password",
+			},
+			Database: "postgres",
+		},
+	}
+
+	// generate random password for testing
+	var raw [32]byte
+	_, err := rand.Read(raw[:])
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	password := hex.EncodeToString(raw[:])
+	creds := credentials.Cleartext{
+		Username: "runner",
+		Password: password,
+	}
+
+	control, err = daisyChain(creds, control, 16)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
 	var g gat.PoolsMap
 
 	transactionPool := pool.NewPool(transaction.Apply(pool.Options{
-- 
GitLab