diff --git a/lib/gat/gatling/conn_pool/worker.go b/lib/gat/gatling/conn_pool/worker.go index f8e45bea30173646356c02a3d10d2e94346a5b6b..0a022f89a0209f4101d32aedd50b618d7f1ad5a3 100644 --- a/lib/gat/gatling/conn_pool/worker.go +++ b/lib/gat/gatling/conn_pool/worker.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "gfx.cafe/gfx/pggat/lib/config" + "gfx.cafe/gfx/pggat/lib/gat/protocol/pg_error" "log" "gfx.cafe/gfx/pggat/lib/gat" @@ -150,10 +151,31 @@ func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, p return fmt.Errorf("describe('%+v') fail: no server", payload) } defer srv.mu.Unlock() - // execute the query - // for now, use primary - // TODO read the query of the underlying prepared statement and choose server accordingly - target := srv.primary + + // get the query text + portal := client.GetPortal(payload.Fields.Name) + if portal == nil { + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("portal '%s' not found", payload.Fields.Name), + } + } + + ps := client.GetPreparedStatement(portal.Fields.PreparedStatement) + if ps == nil { + return &pg_error.Error{ + Severity: pg_error.Err, + Code: pg_error.ProtocolViolation, + Message: fmt.Sprintf("prepared statement '%s' not found", ps.Fields.PreparedStatement), + } + } + + which, err := c.pool.GetRouter().InferRole(ps.Fields.Query) + if err != nil { + return err + } + target := srv.choose(which) if target == nil { return fmt.Errorf("describe('%+v') fail: no server", payload) }