From fc99137e1e1500272a862aabddd9edcf7d9e96c7 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Thu, 14 Sep 2023 14:21:06 -0500
Subject: [PATCH] catch discard all in eqp. deallocate and close will still
 break because there is not enough context

fixes #8
---
 lib/fed/packets/v3.0/commandcomplete.go | 19 ++++++++++++++++
 lib/middleware/middlewares/eqp/state.go | 19 ++++++++++++++++
 test/tester_test.go                     |  3 ++-
 test/tests/discard_all.go               | 30 +++++++++++++++++++++++++
 4 files changed, 70 insertions(+), 1 deletion(-)
 create mode 100644 lib/fed/packets/v3.0/commandcomplete.go
 create mode 100644 test/tests/discard_all.go

diff --git a/lib/fed/packets/v3.0/commandcomplete.go b/lib/fed/packets/v3.0/commandcomplete.go
new file mode 100644
index 00000000..d2f2cb86
--- /dev/null
+++ b/lib/fed/packets/v3.0/commandcomplete.go
@@ -0,0 +1,19 @@
+package packets
+
+import "pggat/lib/fed"
+
+type CommandComplete string
+
+func (T *CommandComplete) ReadFromPacket(packet fed.Packet) bool {
+	if packet.Type() != TypeCommandComplete {
+		return false
+	}
+	packet.ReadString((*string)(T))
+	return true
+}
+
+func (T *CommandComplete) IntoPacket() fed.Packet {
+	packet := fed.NewPacket(TypeCommandComplete, len(*T)+1)
+	packet = packet.AppendString(string(*T))
+	return packet
+}
diff --git a/lib/middleware/middlewares/eqp/state.go b/lib/middleware/middlewares/eqp/state.go
index ea122f63..7079c6b1 100644
--- a/lib/middleware/middlewares/eqp/state.go
+++ b/lib/middleware/middlewares/eqp/state.go
@@ -5,6 +5,7 @@ import (
 
 	"pggat/lib/fed"
 	packets "pggat/lib/fed/packets/v3.0"
+	"pggat/lib/util/maps"
 	"pggat/lib/util/ring"
 )
 
@@ -90,6 +91,8 @@ func (T *State) S2C(packet fed.Packet) {
 		T.ParseComplete()
 	case packets.TypeBindComplete:
 		T.BindComplete()
+	case packets.TypeCommandComplete:
+		T.CommandComplete(packet)
 	case packets.TypeReadyForQuery:
 		T.ReadyForQuery(packet)
 	}
@@ -179,6 +182,22 @@ func (T *State) Query() {
 	delete(T.preparedStatements, "")
 }
 
+// CommandComplete clobbers everything if DISCARD ALL | DEALLOCATE | CLOSE
+func (T *State) CommandComplete(packet fed.Packet) {
+	var commandComplete packets.CommandComplete
+	if !commandComplete.ReadFromPacket(packet) {
+		return
+	}
+
+	if commandComplete == "DISCARD ALL" {
+		maps.Clear(T.preparedStatements)
+		maps.Clear(T.portals)
+		T.pendingPreparedStatements.Clear()
+		T.pendingPortals.Clear()
+		T.pendingCloses.Clear()
+	}
+}
+
 // ReadyForQuery clobbers portals if state == 'I' and pending. Execute on ReadyForQuery S->C
 func (T *State) ReadyForQuery(packet fed.Packet) {
 	var state byte
diff --git a/test/tester_test.go b/test/tester_test.go
index 6d7468aa..cfdb1b47 100644
--- a/test/tester_test.go
+++ b/test/tester_test.go
@@ -47,7 +47,7 @@ func TestTester(t *testing.T) {
 		Password: password,
 	}
 
-	for i := 0; i < 1; i++ {
+	for i := 0; i < 10; i++ {
 		var g gat.PoolsMap
 		p := pool.NewPool(transaction.Apply(pool.Options{
 			Credentials: creds,
@@ -153,6 +153,7 @@ func TestTester(t *testing.T) {
 		tests.EQP8,
 		tests.CopyOut0,
 		tests.CopyOut1,
+		tests.DiscardAll,
 	); err != nil {
 		fmt.Print(err.Error())
 		t.Fail()
diff --git a/test/tests/discard_all.go b/test/tests/discard_all.go
new file mode 100644
index 00000000..4bd05651
--- /dev/null
+++ b/test/tests/discard_all.go
@@ -0,0 +1,30 @@
+package tests
+
+import (
+	"pggat/test"
+	"pggat/test/inst"
+)
+
+var DiscardAll = test.Test{
+	Name: "Discard All",
+	Instructions: []inst.Instruction{
+		inst.Parse{
+			Destination: "a",
+			Query:       "select 0",
+		},
+		inst.Bind{
+			Destination: "a",
+			Source:      "a",
+		},
+		inst.Sync{},
+		inst.SimpleQuery("discard all"),
+		inst.DescribePreparedStatement("a"),
+		inst.Sync{},
+		inst.Parse{
+			Destination: "a",
+			Query:       "select 0",
+		},
+		inst.DescribePreparedStatement("a"),
+		inst.Sync{},
+	},
+}
-- 
GitLab