From b5646107d201c4c9bca8dc11f16df30b74986ffc Mon Sep 17 00:00:00 2001
From: Garet Halliday <ghalliday@gfxlabs.io>
Date: Fri, 9 Sep 2022 15:21:31 -0500
Subject: [PATCH] run prepared statement on server chosen by query contents

---
 lib/gat/gatling/conn_pool/worker.go | 30 +++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/lib/gat/gatling/conn_pool/worker.go b/lib/gat/gatling/conn_pool/worker.go
index f8e45bea..0a022f89 100644
--- a/lib/gat/gatling/conn_pool/worker.go
+++ b/lib/gat/gatling/conn_pool/worker.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"gfx.cafe/gfx/pggat/lib/config"
+	"gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error"
 	"log"
 
 	"gfx.cafe/gfx/pggat/lib/gat"
@@ -150,10 +151,31 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p
 		return fmt.Errorf("describe('%+v') fail: no server", payload)
 	}
 	defer srv.mu.Unlock()
-	// execute the query
-	// for now, use primary
-	// TODO read the query of the underlying prepared statement and choose server accordingly
-	target := srv.primary
+
+	// get the query text
+	portal := client.GetPortal(payload.Fields.Name)
+	if portal == nil {
+		return &pg_error.Error{
+			Severity: pg_error.Err,
+			Code:     pg_error.ProtocolViolation,
+			Message:  fmt.Sprintf("portal '%s' not found", payload.Fields.Name),
+		}
+	}
+
+	ps := client.GetPreparedStatement(portal.Fields.PreparedStatement)
+	if ps == nil {
+		return &pg_error.Error{
+			Severity: pg_error.Err,
+			Code:     pg_error.ProtocolViolation,
+			Message:  fmt.Sprintf("prepared statement '%s' not found", ps.Fields.PreparedStatement),
+		}
+	}
+
+	which, err := c.pool.GetRouter().InferRole(ps.Fields.Query)
+	if err != nil {
+		return err
+	}
+	target := srv.choose(which)
 	if target == nil {
 		return fmt.Errorf("describe('%+v') fail: no server", payload)
 	}
-- 
GitLab