From df87b7cfcbb74f7def6e809842cec828a51966df Mon Sep 17 00:00:00 2001
From: Tom Guinther <tguinther@gfxlabs.io>
Date: Fri, 26 Jul 2024 11:49:51 -0400
Subject: [PATCH] add context

---
 lib/fed/listeners/netconnlistener/listener.go    |  3 ++-
 lib/fed/middleware.go                            | 10 ++++++----
 lib/fed/middlewares/eqp/client.go                |  9 +++++----
 lib/fed/middlewares/eqp/server.go                |  9 +++++----
 lib/fed/middlewares/ps/client.go                 |  9 +++++----
 lib/fed/middlewares/ps/server.go                 |  9 +++++----
 lib/fed/middlewares/unterminate/unterminate.go   |  9 +++++----
 lib/gat/handlers/pool/dialer.go                  |  5 +++--
 lib/gat/handlers/pool/pools/hybrid/middleware.go |  9 +++++----
 lib/gsql/pair.go                                 |  3 +++
 lib/gsql/query_test.go                           |  3 ++-
 11 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/lib/fed/listeners/netconnlistener/listener.go b/lib/fed/listeners/netconnlistener/listener.go
index 0d6c0c55..6ce55f35 100644
--- a/lib/fed/listeners/netconnlistener/listener.go
+++ b/lib/fed/listeners/netconnlistener/listener.go
@@ -1,6 +1,7 @@
 package netconnlistener
 
 import (
+	"context"
 	"net"
 
 	"gfx.cafe/gfx/pggat/lib/fed"
@@ -16,7 +17,7 @@ func (listener *Listener) Accept(fn func(*fed.Conn)) error {
 	if err != nil {
 		return err
 	}
-	fedConn := fed.NewConn(netconncodec.NewCodec(raw))
+	fedConn := fed.NewConn(context.Background(), netconncodec.NewCodec(raw))
 	go func() {
 		fn(fedConn)
 	}()
diff --git a/lib/fed/middleware.go b/lib/fed/middleware.go
index 92e31bd8..1d9668c2 100644
--- a/lib/fed/middleware.go
+++ b/lib/fed/middleware.go
@@ -1,12 +1,14 @@
 package fed
 
+import "context"
+
 // Middleware intercepts packets and possibly changes them. Return a 0 length packet to cancel.
 type Middleware interface {
-	PreRead(typed bool) (Packet, error)
-	ReadPacket(packet Packet) (Packet, error)
+	PreRead(ctx context.Context, typed bool) (Packet, error)
+	ReadPacket(ctx context.Context, packet Packet) (Packet, error)
 
-	WritePacket(packet Packet) (Packet, error)
-	PostWrite() (Packet, error)
+	WritePacket(ctx context.Context, packet Packet) (Packet, error)
+	PostWrite(ctx context.Context) (Packet, error)
 }
 
 func LookupMiddleware[T Middleware](conn *Conn) (T, bool) {
diff --git a/lib/fed/middlewares/eqp/client.go b/lib/fed/middlewares/eqp/client.go
index 652d7073..0476f1b9 100644
--- a/lib/fed/middlewares/eqp/client.go
+++ b/lib/fed/middlewares/eqp/client.go
@@ -1,6 +1,7 @@
 package eqp
 
 import (
+	"context"
 	"gfx.cafe/gfx/pggat/lib/fed"
 )
 
@@ -12,19 +13,19 @@ func NewClient() *Client {
 	return new(Client)
 }
 
-func (T *Client) PreRead(_ bool) (fed.Packet, error) {
+func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) {
 	return nil, nil
 }
 
-func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return T.state.C2S(packet)
 }
 
-func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return T.state.S2C(packet)
 }
 
-func (T *Client) PostWrite() (fed.Packet, error) {
+func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) {
 	return nil, nil
 }
 
diff --git a/lib/fed/middlewares/eqp/server.go b/lib/fed/middlewares/eqp/server.go
index f847a9fa..676fe166 100644
--- a/lib/fed/middlewares/eqp/server.go
+++ b/lib/fed/middlewares/eqp/server.go
@@ -1,6 +1,7 @@
 package eqp
 
 import (
+	"context"
 	"gfx.cafe/gfx/pggat/lib/fed"
 )
 
@@ -12,19 +13,19 @@ func NewServer() *Server {
 	return new(Server)
 }
 
-func (T *Server) PreRead(_ bool) (fed.Packet, error) {
+func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) {
 	return nil, nil
 }
 
-func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Server) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return T.state.S2C(packet)
 }
 
-func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Server) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return T.state.C2S(packet)
 }
 
-func (T *Server) PostWrite() (fed.Packet, error) {
+func (T *Server) PostWrite(ctx context.Context, ) (fed.Packet, error) {
 	return nil, nil
 }
 
diff --git a/lib/fed/middlewares/ps/client.go b/lib/fed/middlewares/ps/client.go
index 94b5014f..d7f383b6 100644
--- a/lib/fed/middlewares/ps/client.go
+++ b/lib/fed/middlewares/ps/client.go
@@ -1,6 +1,7 @@
 package ps
 
 import (
+	"context"
 	"gfx.cafe/gfx/pggat/lib/fed"
 	packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
 	"gfx.cafe/gfx/pggat/lib/util/maps"
@@ -18,15 +19,15 @@ func NewClient(parameters map[strutil.CIString]string) *Client {
 	}
 }
 
-func (T *Client) PreRead(_ bool) (fed.Packet, error) {
+func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) {
 	return nil, nil
 }
 
-func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return packet, nil
 }
 
-func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	switch packet.Type() {
 	case packets.TypeParameterStatus:
 		var p packets.ParameterStatus
@@ -49,7 +50,7 @@ func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) {
 	}
 }
 
-func (T *Client) PostWrite() (fed.Packet, error) {
+func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) {
 	return nil, nil
 }
 
diff --git a/lib/fed/middlewares/ps/server.go b/lib/fed/middlewares/ps/server.go
index 137b34df..2f922850 100644
--- a/lib/fed/middlewares/ps/server.go
+++ b/lib/fed/middlewares/ps/server.go
@@ -1,6 +1,7 @@
 package ps
 
 import (
+	"context"
 	"gfx.cafe/gfx/pggat/lib/fed"
 	packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
 	"gfx.cafe/gfx/pggat/lib/util/strutil"
@@ -16,11 +17,11 @@ func NewServer(parameters map[strutil.CIString]string) *Server {
 	}
 }
 
-func (T *Server) PreRead(_ bool) (fed.Packet, error) {
+func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) {
 	return nil, nil
 }
 
-func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Server) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	switch packet.Type() {
 	case packets.TypeParameterStatus:
 		var p packets.ParameterStatus
@@ -39,11 +40,11 @@ func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) {
 	}
 }
 
-func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Server) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return packet, nil
 }
 
-func (T *Server) PostWrite() (fed.Packet, error) {
+func (T *Server) PostWrite(ctx context.Context) (fed.Packet, error) {
 	return nil, nil
 }
 
diff --git a/lib/fed/middlewares/unterminate/unterminate.go b/lib/fed/middlewares/unterminate/unterminate.go
index 073dbe10..633de635 100644
--- a/lib/fed/middlewares/unterminate/unterminate.go
+++ b/lib/fed/middlewares/unterminate/unterminate.go
@@ -1,6 +1,7 @@
 package unterminate
 
 import (
+	"context"
 	"io"
 
 	"gfx.cafe/gfx/pggat/lib/fed"
@@ -13,22 +14,22 @@ var Unterminate = unterm{}
 
 type unterm struct{}
 
-func (unterm) PreRead(_ bool) (fed.Packet, error) {
+func (unterm) PreRead(ctx context.Context, _ bool) (fed.Packet, error) {
 	return nil, nil
 }
 
-func (unterm) ReadPacket(packet fed.Packet) (fed.Packet, error) {
+func (unterm) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	if packet.Type() == packets.TypeTerminate {
 		return packet, io.EOF
 	}
 	return packet, nil
 }
 
-func (unterm) WritePacket(packet fed.Packet) (fed.Packet, error) {
+func (unterm) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	return packet, nil
 }
 
-func (unterm) PostWrite() (fed.Packet, error) {
+func (unterm) PostWrite(ctx context.Context, ) (fed.Packet, error) {
 	return nil, nil
 }
 
diff --git a/lib/gat/handlers/pool/dialer.go b/lib/gat/handlers/pool/dialer.go
index 1da5d9fe..49ec224a 100644
--- a/lib/gat/handlers/pool/dialer.go
+++ b/lib/gat/handlers/pool/dialer.go
@@ -1,6 +1,7 @@
 package pool
 
 import (
+	"context"
 	"crypto/tls"
 	"encoding/json"
 	"fmt"
@@ -66,7 +67,7 @@ func (T *Dialer) Dial() (*fed.Conn, error) {
 	if err != nil {
 		return nil, err
 	}
-	conn := fed.NewConn(netconncodec.NewCodec(c))
+	conn := fed.NewConn(context.Background(), netconncodec.NewCodec(c))
 	conn.User = T.Username
 	conn.Database = T.Database
 	err = backends.Accept(
@@ -90,7 +91,7 @@ func (T *Dialer) Cancel(key fed.BackendKey) {
 	if err != nil {
 		return
 	}
-	conn := fed.NewConn(netconncodec.NewCodec(c))
+	conn := fed.NewConn(context.Background(), netconncodec.NewCodec(c))
 	defer func() {
 		_ = conn.Close()
 	}()
diff --git a/lib/gat/handlers/pool/pools/hybrid/middleware.go b/lib/gat/handlers/pool/pools/hybrid/middleware.go
index 230ab054..ac6ff340 100644
--- a/lib/gat/handlers/pool/pools/hybrid/middleware.go
+++ b/lib/gat/handlers/pool/pools/hybrid/middleware.go
@@ -1,6 +1,7 @@
 package hybrid
 
 import (
+	"context"
 	"gfx.cafe/gfx/pggat/lib/fed"
 	packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
 	"gfx.cafe/gfx/pggat/lib/perror"
@@ -20,7 +21,7 @@ func NewMiddleware() *Middleware {
 	return m
 }
 
-func (T *Middleware) PreRead(typed bool) (fed.Packet, error) {
+func (T *Middleware) PreRead(ctx context.Context, typed bool) (fed.Packet, error) {
 	if !T.primary {
 		return nil, nil
 	}
@@ -37,7 +38,7 @@ func (T *Middleware) PreRead(typed bool) (fed.Packet, error) {
 	}, nil
 }
 
-func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Middleware) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	if T.primary {
 		return packet, nil
 	}
@@ -60,7 +61,7 @@ func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) {
 	return p, nil
 }
 
-func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) {
+func (T *Middleware) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) {
 	if T.primary && (T.buf.Buffered() > 0 || T.bufDec.Buffered() > 0) {
 		return nil, nil
 	}
@@ -84,7 +85,7 @@ func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) {
 	return packet, nil
 }
 
-func (T *Middleware) PostWrite() (fed.Packet, error) {
+func (T *Middleware) PostWrite(ctx context.Context) (fed.Packet, error) {
 	return nil, nil
 }
 
diff --git a/lib/gsql/pair.go b/lib/gsql/pair.go
index 2fcaff95..857f8fe1 100644
--- a/lib/gsql/pair.go
+++ b/lib/gsql/pair.go
@@ -1,6 +1,7 @@
 package gsql
 
 import (
+	"context"
 	"net"
 
 	"gfx.cafe/gfx/pggat/lib/fed"
@@ -13,12 +14,14 @@ func NewPair() (*fed.Conn, *fed.Conn, net.Conn, net.Conn) {
 	in := mio.InwardConn{Conn: conn}
 	out := mio.OutwardConn{Conn: conn}
 	inward := fed.NewConn(
+		context.Background(),
 		netconncodec.NewCodec(
 			in,
 		),
 	)
 	inward.Ready = true
 	outward := fed.NewConn(
+		context.Background(),
 		netconncodec.NewCodec(
 			out,
 		),
diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go
index 476748de..0859a6ba 100644
--- a/lib/gsql/query_test.go
+++ b/lib/gsql/query_test.go
@@ -1,6 +1,7 @@
 package gsql_test
 
 import (
+	"context"
 	"crypto/tls"
 	"log"
 	"net"
@@ -32,7 +33,7 @@ func TestQuery(t *testing.T) {
 		t.Error(err)
 		return
 	}
-	server := fed.NewConn(netconncodec.NewCodec(s))
+	server := fed.NewConn(context.Background(), netconncodec.NewCodec(s))
 	err = backends.Accept(
 		server,
 		"disable",
-- 
GitLab