From 7d0aa7a0d3ce06e6c538ec7a122c8a81e03685e9 Mon Sep 17 00:00:00 2001
From: Garet Halliday <ghalliday@gfxlabs.io>
Date: Fri, 30 Sep 2022 12:45:43 -0500
Subject: [PATCH] fix

---
 .../query_router/query_router_test.go         |   2 +-
 lib/gat/gatling/client/client.go              |  52 ++-
 lib/gat/gatling/server/server.go              | 312 ++++++++----------
 lib/gat/interfaces.go                         |   6 +-
 lib/gat/pool/session/pool.go                  |   6 +-
 lib/gat/pool/transaction/worker.go            |  26 +-
 6 files changed, 192 insertions(+), 212 deletions(-)

diff --git a/lib/gat/database/query_router/query_router_test.go b/lib/gat/database/query_router/query_router_test.go
index a68247b6..f94af3a9 100644
--- a/lib/gat/database/query_router/query_router_test.go
+++ b/lib/gat/database/query_router/query_router_test.go
@@ -8,7 +8,7 @@ import (
 
 // TODO: adapt tests
 func TestQueryRouterInterRoleReplica(t *testing.T) {
-	qr := DefaultRouter
+	qr := DefaultRouter(nil)
 	role, err := qr.InferRole(`UPDATE items SET name = 'pumpkin' WHERE id = 5`)
 	if err != nil {
 		t.Fatal(err)
diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go
index 20ac2ca5..d6d9163c 100644
--- a/lib/gat/gatling/client/client.go
+++ b/lib/gat/gatling/client/client.go
@@ -92,7 +92,6 @@ type Client struct {
 	statements  map[string]*protocol.Parse
 	portals     map[string]*protocol.Bind
 	conf        *config.Global
-	status      rune
 
 	parser *pg3p.Parser
 
@@ -181,7 +180,6 @@ func NewClient(
 		gatling:    gatling,
 		statements: make(map[string]*protocol.Parse),
 		portals:    make(map[string]*protocol.Bind),
-		status:     'I',
 		conf:       conf,
 		parser:     pg3p.NewParser(),
 	}
@@ -429,13 +427,11 @@ func (c *Client) Accept(ctx context.Context) error {
 				return err
 			}
 		}
-		if c.status == 'I' {
-			rq := new(protocol.ReadyForQuery)
-			rq.Fields.Status = 'I'
-			err = c.Send(rq)
-			if err != nil {
-				return err
-			}
+		rq := new(protocol.ReadyForQuery)
+		rq.Fields.Status = 'I'
+		err = c.Send(rq)
+		if err != nil {
+			return err
 		}
 	}
 	return nil
@@ -450,8 +446,23 @@ func (c *Client) recvLoop() {
 			}
 			break
 		}
-		log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv)
-		c.recv <- recv
+		//log.Printf("got packet(%s) %+v", reflect.TypeOf(recv), recv)
+		switch pkt := recv.(type) {
+		case *protocol.Parse:
+			c.statements[pkt.Fields.PreparedStatement] = pkt
+			err = c.Send(new(protocol.ParseComplete))
+			if err != nil {
+				break
+			}
+		case *protocol.Bind:
+			c.portals[pkt.Fields.Destination] = pkt
+			err = c.Send(new(protocol.BindComplete))
+			if err != nil {
+				break
+			}
+		default:
+			c.recv <- recv
+		}
 	}
 }
 
@@ -479,16 +490,11 @@ func (c *Client) tick(ctx context.Context) (bool, error) {
 		return false, ctx.Err()
 	}
 	switch cast := rsp.(type) {
-	case *protocol.Parse:
-		return true, c.parse(ctx, cast)
-	case *protocol.Bind:
-		return true, c.bind(ctx, cast)
 	case *protocol.Describe:
 		return true, c.handle_describe(ctx, cast)
 	case *protocol.Execute:
 		return true, c.handle_execute(ctx, cast)
 	case *protocol.Sync:
-		c.status = 'I'
 		return true, nil
 	case *protocol.Query:
 		return true, c.handle_query(ctx, cast)
@@ -502,28 +508,14 @@ func (c *Client) tick(ctx context.Context) (bool, error) {
 	return true, nil
 }
 
-func (c *Client) parse(ctx context.Context, q *protocol.Parse) error {
-	c.statements[q.Fields.PreparedStatement] = q
-	c.status = 'T'
-	return c.Send(new(protocol.ParseComplete))
-}
-
-func (c *Client) bind(ctx context.Context, b *protocol.Bind) error {
-	c.portals[b.Fields.Destination] = b
-	c.status = 'T'
-	return c.Send(new(protocol.BindComplete))
-}
-
 func (c *Client) handle_describe(ctx context.Context, d *protocol.Describe) error {
 	//log.Println("describe")
-	c.status = 'T'
 	c.startRequest()
 	return c.server.Describe(ctx, c, d)
 }
 
 func (c *Client) handle_execute(ctx context.Context, e *protocol.Execute) error {
 	//log.Println("execute")
-	c.status = 'T'
 	c.startRequest()
 	return c.server.Execute(ctx, c, e)
 }
diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go
index 667991cd..7ef62ee3 100644
--- a/lib/gat/gatling/server/server.go
+++ b/lib/gat/gatling/server/server.go
@@ -2,7 +2,6 @@ package server
 
 import (
 	"bufio"
-	"errors"
 	"fmt"
 	"net"
 	"reflect"
@@ -53,7 +52,8 @@ type Server struct {
 	dbpass string
 	user   config.User
 
-	healthy bool
+	healthy      bool
+	awaitingSync bool
 
 	log zlog.Logger
 
@@ -387,6 +387,14 @@ func (s *Server) readPacket() (protocol.Packet, error) {
 	return p, err
 }
 
+func (s *Server) stabilize() {
+	// TODO actually stabilize connection
+	if s.awaitingSync {
+		_ = s.writePacket(new(protocol.Sync))
+		_ = s.flush()
+	}
+}
+
 func (s *Server) ensurePreparedStatement(client gat.Client, name string) error {
 	// send prepared statement
 	stmt := client.GetPreparedStatement(name)
@@ -471,183 +479,174 @@ func (s *Server) destructPortal(name string) {
 	s.destructPreparedStatement(portal.Fields.PreparedStatement)
 }
 
-func (s *Server) Describe(client gat.Client, d *protocol.Describe) error {
-	switch d.Fields.Which {
-	case 'S': // prepared statement
-		err := s.ensurePreparedStatement(client, d.Fields.Name)
+func (s *Server) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error {
+	return s.sendAndLink(ctx, client, d)
+}
+
+func (s *Server) handleRecv(client gat.Client, packet protocol.Packet) error {
+	switch pkt := packet.(type) {
+	case *protocol.FunctionCall, *protocol.Query:
+		err := s.writePacket(packet)
 		if err != nil {
 			return err
 		}
-	case 'P': // portal
-		err := s.ensurePortal(client, d.Fields.Name)
+		err = s.flush()
 		if err != nil {
 			return err
 		}
-	default:
-		return &pg_error.Error{
-			Severity: pg_error.Err,
-			Code:     pg_error.ProtocolViolation,
-			Message:  fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", d.Fields.Which),
+	case *protocol.Describe:
+		s.awaitingSync = true
+		switch pkt.Fields.Which {
+		case 'S': // prepared statement
+			err := s.ensurePreparedStatement(client, pkt.Fields.Name)
+			if err != nil {
+				return err
+			}
+		case 'P': // portal
+			err := s.ensurePortal(client, pkt.Fields.Name)
+			if err != nil {
+				return err
+			}
+		default:
+			return &pg_error.Error{
+				Severity: pg_error.Err,
+				Code:     pg_error.ProtocolViolation,
+				Message:  fmt.Sprintf("expected 'S' or 'P' for describe target, got '%c'", pkt.Fields.Which),
+			}
 		}
-	}
 
-	// now we actually execute the thing the client wants
-	err := s.writePacket(d)
-	if err != nil {
-		return err
+		// now we actually execute the thing the client wants
+		err := s.writePacket(packet)
+		if err != nil {
+			return err
+		}
+	case *protocol.Execute:
+		s.awaitingSync = true
+		err := s.ensurePortal(client, pkt.Fields.Name)
+		if err != nil {
+			return err
+		}
+
+		err = s.writePacket(pkt)
+		if err != nil {
+			return err
+		}
+	case *protocol.Sync:
+		s.awaitingSync = false
+		err := s.writePacket(packet)
+		if err != nil {
+			return err
+		}
+		err = s.flush()
+		if err != nil {
+			return err
+		}
 	}
-	err = s.writePacket(new(protocol.Sync))
+	return nil
+}
+
+func (s *Server) sendAndLink(ctx context.Context, client gat.Client, initial protocol.Packet) error {
+	err := s.handleRecv(client, initial)
 	if err != nil {
 		return err
 	}
-	err = s.flush()
+	err = s.awaitSync(ctx, client)
 	if err != nil {
 		return err
 	}
+	return s.link(ctx, client)
+}
 
-	return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
-		//log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt)
-		switch pkt.(type) {
+func (s *Server) link(ctx context.Context, client gat.Client) error {
+	defer s.stabilize()
+	for {
+		pkt, err := s.readPacket()
+		if err != nil {
+			return err
+		}
+
+		switch p := pkt.(type) {
 		case *protocol.BindComplete, *protocol.ParseComplete:
+			// ignore, it is because we bound stuff
 		case *protocol.ReadyForQuery:
-			finish = true
+			if p.Fields.Status == 'I' {
+				// this client is done
+				return nil
+			}
+
+			err = client.Send(p)
+			if err != nil {
+				return err
+			}
+			err = client.Flush()
+			if err != nil {
+				return err
+			}
+
+			err = s.handleClientPacket(ctx, client)
+			if err != nil {
+				return err
+			}
+			err = s.awaitSync(ctx, client)
+			if err != nil {
+				return err
+			}
+		case *protocol.CopyInResponse:
+			err = client.Send(p)
+			if err != nil {
+				return err
+			}
+			err = client.Flush()
+			if err != nil {
+				return err
+			}
+			err = s.CopyIn(ctx, client)
+			if err != nil {
+				return err
+			}
 		default:
-			forward = true
+			err = client.Send(p)
+			if err != nil {
+				return err
+			}
 		}
-		return
-	})
+	}
 }
 
-func (s *Server) Execute(client gat.Client, e *protocol.Execute) error {
-	log.Printf("execute `%s`", e.Fields.Name)
-	err := s.ensurePortal(client, e.Fields.Name)
-	if err != nil {
-		return err
+func (s *Server) handleClientPacket(ctx context.Context, client gat.Client) error {
+	select {
+	case pkt := <-client.Recv():
+		return s.handleRecv(client, pkt)
+	case <-ctx.Done():
+		return ctx.Err()
 	}
+}
 
-	err = s.writePacket(e)
-	if err != nil {
-		return err
-	}
-	err = s.writePacket(new(protocol.Sync))
-	if err != nil {
-		return err
-	}
-	err = s.flush()
-	if err != nil {
-		return err
+func (s *Server) awaitSync(ctx context.Context, client gat.Client) error {
+	for s.awaitingSync {
+		err := s.handleClientPacket(ctx, client)
+		if err != nil {
+			return err
+		}
 	}
+	return nil
+}
 
-	return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
-		//log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt)
-		switch p := pkt.(type) {
-		case *protocol.BindComplete, *protocol.ParseComplete:
-		case *protocol.ReadyForQuery:
-			if p.Fields.Status != 'I' {
-				err = errors.New("transactions are not allowed in statements")
-
-				end := new(protocol.Query)
-				end.Fields.Query = "END"
-				_ = s.writePacket(end)
-				_ = s.flush()
-			} else {
-				finish = true
-			}
-		default:
-			forward = true
-		}
-		return
-	})
+func (s *Server) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error {
+	return s.sendAndLink(ctx, client, e)
 }
 
 func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query string) error {
 	// send to server
 	q := new(protocol.Query)
 	q.Fields.Query = query
-	err := s.writePacket(q)
-	if err != nil {
-		return err
-	}
-	err = s.flush()
-	if err != nil {
-		return err
-	}
-
-	// this function seems wild but it has to be the way it is so we read the whole response, even if the
-	// client fails midway
-	// read responses
-	return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
-		//log.Printf("forwarding pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt)
-		switch pkt.(type) {
-		case *protocol.ReadyForQuery:
-			// all ReadyForQuery packets end a simple query, regardless of type
-			finish = true
-		case *protocol.CopyInResponse:
-			_ = client.Send(pkt)
-			err = s.CopyIn(ctx, client)
-		default:
-			forward = true
-		}
-		return
-	})
+	return s.sendAndLink(ctx, client, q)
 }
 
 func (s *Server) Transaction(ctx context.Context, client gat.Client, query string) error {
 	q := new(protocol.Query)
 	q.Fields.Query = query
-	err := s.writePacket(q)
-	if err != nil {
-		return err
-	}
-	err = s.flush()
-	if err != nil {
-		return err
-	}
-	return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
-		//log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt)
-		switch p := pkt.(type) {
-		case *protocol.ReadyForQuery:
-			if p.Fields.Status != 'I' {
-				// send to client and wait for next query
-				err = client.Send(pkt)
-
-				if err == nil {
-					err = client.Flush()
-					if err == nil {
-						select {
-						case r := <-client.Recv():
-							//log.Printf("got client pkt pkt(%s): %+v", reflect.TypeOf(r), r)
-							switch r.(type) {
-							case *protocol.Query:
-								//forward to server
-								_ = s.writePacket(r)
-								_ = s.flush()
-							default:
-								err = fmt.Errorf("expected a query in transaction state but got something else")
-							}
-						case <-ctx.Done():
-							err = ctx.Err()
-						}
-					}
-				}
-
-				if err != nil {
-					end := new(protocol.Query)
-					end.Fields.Query = "END"
-					_ = s.writePacket(end)
-					_ = s.flush()
-				}
-			} else {
-				finish = true
-			}
-		case *protocol.CopyInResponse:
-			_ = client.Send(pkt)
-			err = s.CopyIn(ctx, client)
-		default:
-			forward = true
-		}
-		return
-	})
+	return s.sendAndLink(ctx, client, q)
 }
 
 func (s *Server) CopyIn(ctx context.Context, client gat.Client) error {
@@ -656,19 +655,15 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error {
 		return err
 	}
 	for {
-		// detect a disconneted /hanging client by waiting 30 seoncds, else timeout
-		// otherwise, just keep reading packets until a done or error is received
-		cctx, cancel := context.WithTimeout(ctx, 30*time.Second)
 		var pkt protocol.Packet
 		// receive a packet, or done if the ctx gets canceled
 		select {
 		case pkt = <-client.Recv():
-		case <-cctx.Done():
+		case <-ctx.Done():
 			_ = s.writePacket(new(protocol.CopyFail))
 			_ = s.flush()
-			return cctx.Err()
+			return ctx.Err()
 		}
-		cancel()
 		err = s.writePacket(pkt)
 		if err != nil {
 			return err
@@ -682,25 +677,8 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error {
 	}
 }
 
-func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) error {
-	err := s.writePacket(payload)
-	if err != nil {
-		return err
-	}
-	err = s.flush()
-	if err != nil {
-		return err
-	}
-	// read responses
-	return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
-		switch pkt.(type) {
-		case *protocol.ReadyForQuery: // status 'I' should only be encountered here
-			finish = true
-		default:
-			forward = true
-		}
-		return
-	})
+func (s *Server) CallFunction(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error {
+	return s.sendAndLink(ctx, client, payload)
 }
 
 func (s *Server) Close(ctx context.Context) error {
diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go
index a8af8c35..fc4ab08e 100644
--- a/lib/gat/interfaces.go
+++ b/lib/gat/interfaces.go
@@ -130,9 +130,9 @@ type Connection interface {
 	IsCloseNeeded() bool
 
 	// actions
-	Describe(client Client, payload *protocol.Describe) error
-	Execute(client Client, payload *protocol.Execute) error
-	CallFunction(client Client, payload *protocol.FunctionCall) error
+	Describe(ctx context.Context, client Client, payload *protocol.Describe) error
+	Execute(ctx context.Context, client Client, payload *protocol.Execute) error
+	CallFunction(ctx context.Context, client Client, payload *protocol.FunctionCall) error
 	SimpleQuery(ctx context.Context, client Client, payload string) error
 	Transaction(ctx context.Context, client Client, payload string) error
 
diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go
index 24986a8a..f58a256b 100644
--- a/lib/gat/pool/session/pool.go
+++ b/lib/gat/pool/session/pool.go
@@ -98,7 +98,7 @@ func (p *Pool) Describe(ctx context.Context, client gat.Client, describe *protoc
 	if err != nil {
 		return err
 	}
-	return c.Describe(client, describe)
+	return c.Describe(ctx, client, describe)
 }
 
 func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol.Execute) error {
@@ -106,7 +106,7 @@ func (p *Pool) Execute(ctx context.Context, client gat.Client, execute *protocol
 	if err != nil {
 		return err
 	}
-	return c.Execute(client, execute)
+	return c.Execute(ctx, client, execute)
 }
 
 func (p *Pool) SimpleQuery(ctx context.Context, client gat.Client, query string) error {
@@ -130,7 +130,7 @@ func (p *Pool) CallFunction(ctx context.Context, client gat.Client, payload *pro
 	if err != nil {
 		return err
 	}
-	return c.CallFunction(client, payload)
+	return c.CallFunction(ctx, client, payload)
 }
 
 var _ gat.Pool = (*Pool)(nil)
diff --git a/lib/gat/pool/transaction/worker.go b/lib/gat/pool/transaction/worker.go
index 94a3083a..f1ff43d9 100644
--- a/lib/gat/pool/transaction/worker.go
+++ b/lib/gat/pool/transaction/worker.go
@@ -104,7 +104,9 @@ func (w *worker) HandleDescribe(ctx context.Context, c gat.Client, d *protocol.D
 	defer w.ret()
 
 	if w.w.user.StatementTimeout != 0 {
-		ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		var done context.CancelFunc
+		ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		defer done()
 	}
 
 	errch := make(chan error)
@@ -128,7 +130,9 @@ func (w *worker) HandleExecute(ctx context.Context, c gat.Client, e *protocol.Ex
 	defer w.ret()
 
 	if w.w.user.StatementTimeout != 0 {
-		ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		var done context.CancelFunc
+		ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		defer done()
 	}
 
 	errch := make(chan error)
@@ -152,7 +156,9 @@ func (w *worker) HandleFunction(ctx context.Context, c gat.Client, fn *protocol.
 	defer w.ret()
 
 	if w.w.user.StatementTimeout != 0 {
-		ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		var done context.CancelFunc
+		ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		defer done()
 	}
 
 	errch := make(chan error)
@@ -176,7 +182,9 @@ func (w *worker) HandleSimpleQuery(ctx context.Context, c gat.Client, query stri
 	defer w.ret()
 
 	if w.w.user.StatementTimeout != 0 {
-		ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		var done context.CancelFunc
+		ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		defer done()
 	}
 
 	start := time.Now()
@@ -206,7 +214,9 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri
 	defer w.ret()
 
 	if w.w.user.StatementTimeout != 0 {
-		ctx, _ = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		var done context.CancelFunc
+		ctx, done = context.WithTimeout(ctx, time.Duration(w.w.user.StatementTimeout)*time.Millisecond)
+		defer done()
 	}
 
 	start := time.Now()
@@ -259,7 +269,7 @@ func (w *worker) z_actually_do_describe(ctx context.Context, client gat.Client,
 	}
 	w.setCurrentBinding(client, target)
 	defer w.unsetCurrentBinding(client, target)
-	return target.Describe(client, payload)
+	return target.Describe(ctx, client, payload)
 }
 func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, payload *protocol.Execute) error {
 	srv := w.chooseShard(client)
@@ -299,7 +309,7 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p
 	if target == nil {
 		return fmt.Errorf("describe('%+v') fail: no server", payload)
 	}
-	return target.Execute(client, payload)
+	return target.Execute(ctx, client, payload)
 }
 func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error {
 	srv := w.chooseShard(client)
@@ -316,7 +326,7 @@ func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payloa
 	}
 	w.setCurrentBinding(client, target)
 	defer w.unsetCurrentBinding(client, target)
-	err := target.CallFunction(client, payload)
+	err := target.CallFunction(ctx, client, payload)
 	if err != nil {
 		return fmt.Errorf("fn('%+v') fail: %w ", payload, err)
 	}
-- 
GitLab