diff --git a/lib/gat/conn_pool.go b/lib/gat/conn_pool.go index 5babf9219d98abfe844f34f5d7ccab5d26739e95..b88f55fe3c677eec6532683e4a20b8f828ae414f 100644 --- a/lib/gat/conn_pool.go +++ b/lib/gat/conn_pool.go @@ -10,4 +10,5 @@ type ConnectionPool interface { GetUser() *config.User GetServerInfo() []*protocol.ParameterStatus Query(ctx context.Context, query string) (<-chan protocol.Packet, error) + CallFunction(ctx context.Context, payload *protocol.FunctionCall) (<-chan protocol.Packet, error) } diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 6ea1a887d5c39b256d2ee563795345a7acfe14d7..2ad1fa6c1a63dda117058e880f5965bf70ac58e0 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -292,6 +292,8 @@ func (c *Client) tick(ctx context.Context) (bool, error) { switch cast := rsp.(type) { case *protocol.Query: return true, c.handle_query(ctx, cast) + case *protocol.FunctionCall: + return true, c.handle_function(ctx, cast) case *protocol.Terminate: return false, nil default: @@ -299,22 +301,33 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return true, nil } -func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { - rep, err := c.server.Query(ctx, q.Fields.Query) - if err != nil { - return err - } +func (c *Client) forward(pkts <-chan protocol.Packet) error { for { - rsp := <-rep + rsp := <-pkts if rsp == nil { - break + return nil } - err = c.Send(rsp) + err := c.Send(rsp) if err != nil { return err } } - return nil +} + +func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { + rep, err := c.server.Query(ctx, q.Fields.Query) + if err != nil { + return err + } + return c.forward(rep) +} + +func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) error { + rep, err := c.server.CallFunction(ctx, f) + if err != nil { + return err + } + return c.forward(rep) } /* diff --git a/lib/gat/gatling/conn_pool/conn_pool.go b/lib/gat/gatling/conn_pool/conn_pool.go index 0b1bc70b4c8d34d7be5a995de84e022e74cb2ed4..6071e64c38ad31023f28a773670194ff5e2ee2cc 100644 --- a/lib/gat/gatling/conn_pool/conn_pool.go +++ b/lib/gat/gatling/conn_pool/conn_pool.go @@ -14,9 +14,9 @@ import ( "sync" ) -type query struct { - query string - rep chan<- protocol.Packet +type request[T any] struct { + payload T + rep chan<- protocol.Packet } type servers struct { @@ -34,20 +34,22 @@ type shard struct { } type ConnectionPool struct { - c *config.Pool - user *config.User - pool gat.Pool - shards []shard - queries chan query + c *config.Pool + user *config.User + pool gat.Pool + shards []shard + queries chan request[string] + functionCalls chan request[*protocol.FunctionCall] mu sync.RWMutex } func NewConnectionPool(pool gat.Pool, conf *config.Pool, user *config.User) *ConnectionPool { p := &ConnectionPool{ - user: user, - pool: pool, - queries: make(chan query), + user: user, + pool: pool, + queries: make(chan request[string]), + functionCalls: make(chan request[*protocol.FunctionCall]), } p.EnsureConfig(conf) for i := 0; i < user.PoolSize; i++ { @@ -133,22 +135,38 @@ func (c *ConnectionPool) chooseServer(query string) *servers { func (c *ConnectionPool) worker() { for { - q := <-c.queries - - srv := c.chooseServer(q.query) - if srv == nil { - log.Printf("call to query '%s' failed", q.query) - continue + select { + case q := <-c.queries: + srv := c.chooseServer(q.payload) + if srv == nil { + log.Printf("call to query '%s' failed", q.payload) + continue + } + + // run the query + err := srv.primary.Query(q.payload, q.rep) + srv.mu.Unlock() + + if err != nil { + log.Println(err) + } + close(q.rep) + case f := <-c.functionCalls: + srv := c.chooseServer("") + if srv == nil { + log.Printf("function call '%+v' failed", f.payload) + continue + } + + // run the query + err := srv.primary.CallFunction(f.payload, f.rep) + srv.mu.Unlock() + + if err != nil { + log.Println(err) + } + close(f.rep) } - - // run the query - err := srv.primary.Query(q.query, q.rep) - srv.mu.Unlock() - - if err != nil { - log.Println(err) - } - close(q.rep) } } @@ -168,9 +186,20 @@ func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus { func (c *ConnectionPool) Query(ctx context.Context, q string) (<-chan protocol.Packet, error) { rep := make(chan protocol.Packet) - c.queries <- query{ - query: q, - rep: rep, + c.queries <- request[string]{ + payload: q, + rep: rep, + } + + return rep, nil +} + +func (c *ConnectionPool) CallFunction(ctx context.Context, f *protocol.FunctionCall) (<-chan protocol.Packet, error) { + rep := make(chan protocol.Packet) + + c.functionCalls <- request[*protocol.FunctionCall]{ + payload: f, + rep: rep, } return rep, nil diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index 874da5ee7cd940d80cf8e6ad7a3708176b9c9a32..65711dd36907761dcb4993c1bf093fde7a92f605 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -194,6 +194,23 @@ func (s *Server) connect(ctx context.Context) error { } } +func (s *Server) forwardTo(rep chan<- protocol.Packet, predicate func(pkt protocol.Packet) (forward bool, finish bool)) error { + for { + var rsp protocol.Packet + rsp, err := protocol.ReadBackend(s.r) + if err != nil { + return err + } + forward, finish := predicate(rsp) + if forward { + rep <- rsp + } + if finish { + return nil + } + } +} + func (s *Server) Query(query string, rep chan<- protocol.Packet) error { // send to server q := new(protocol.Query) @@ -204,23 +221,34 @@ func (s *Server) Query(query string, rep chan<- protocol.Packet) error { } // read responses - for { - var rsp protocol.Packet - rsp, err = protocol.ReadBackend(s.r) - if err != nil { - return err - } - switch r := rsp.(type) { + return s.forwardTo(rep, func(pkt protocol.Packet) (forward bool, finish bool) { + switch r := pkt.(type) { case *protocol.ReadyForQuery: - if r.Fields.Status == 'I' { - rep <- rsp - return nil - } + return true, r.Fields.Status == 'I' case *protocol.CopyInResponse, *protocol.CopyOutResponse, *protocol.CopyBothResponse: - return fmt.Errorf("unsuported") + log.Println("client tried to enter copy mode") + return false, true + default: + return true, false } - rep <- rsp + }) +} + +func (s *Server) CallFunction(payload *protocol.FunctionCall, rep chan<- protocol.Packet) error { + _, err := payload.Write(s.wr) + if err != nil { + return err } + + // read responses + return s.forwardTo(rep, func(pkt protocol.Packet) (forward bool, finish bool) { + switch r := pkt.(type) { + case *protocol.ReadyForQuery: + return true, r.Fields.Status == 'I' + default: + return true, false + } + }) } func (s *Server) Close(ctx context.Context) error {