good morning!!!!

Skip to content
Snippets Groups Projects
Commit 80999fc8 authored by Garet Halliday's avatar Garet Halliday
Browse files

function calls

parent 4a9fb33a
No related branches found
No related tags found
No related merge requests found
......@@ -10,4 +10,5 @@ type ConnectionPool interface {
GetUser() *config.User
GetServerInfo() []*protocol.ParameterStatus
Query(ctx context.Context, query string) (<-chan protocol.Packet, error)
CallFunction(ctx context.Context, payload *protocol.FunctionCall) (<-chan protocol.Packet, error)
}
......@@ -292,6 +292,8 @@ func (c *Client) tick(ctx context.Context) (bool, error) {
switch cast := rsp.(type) {
case *protocol.Query:
return true, c.handle_query(ctx, cast)
case *protocol.FunctionCall:
return true, c.handle_function(ctx, cast)
case *protocol.Terminate:
return false, nil
default:
......@@ -299,22 +301,33 @@ func (c *Client) tick(ctx context.Context) (bool, error) {
return true, nil
}
func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error {
rep, err := c.server.Query(ctx, q.Fields.Query)
if err != nil {
return err
}
func (c *Client) forward(pkts <-chan protocol.Packet) error {
for {
rsp := <-rep
rsp := <-pkts
if rsp == nil {
break
return nil
}
err = c.Send(rsp)
err := c.Send(rsp)
if err != nil {
return err
}
}
return nil
}
func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error {
rep, err := c.server.Query(ctx, q.Fields.Query)
if err != nil {
return err
}
return c.forward(rep)
}
func (c *Client) handle_function(ctx context.Context, f *protocol.FunctionCall) error {
rep, err := c.server.CallFunction(ctx, f)
if err != nil {
return err
}
return c.forward(rep)
}
/*
......
......@@ -14,9 +14,9 @@ import (
"sync"
)
type query struct {
query string
rep chan<- protocol.Packet
type request[T any] struct {
payload T
rep chan<- protocol.Packet
}
type servers struct {
......@@ -34,20 +34,22 @@ type shard struct {
}
type ConnectionPool struct {
c *config.Pool
user *config.User
pool gat.Pool
shards []shard
queries chan query
c *config.Pool
user *config.User
pool gat.Pool
shards []shard
queries chan request[string]
functionCalls chan request[*protocol.FunctionCall]
mu sync.RWMutex
}
func NewConnectionPool(pool gat.Pool, conf *config.Pool, user *config.User) *ConnectionPool {
p := &ConnectionPool{
user: user,
pool: pool,
queries: make(chan query),
user: user,
pool: pool,
queries: make(chan request[string]),
functionCalls: make(chan request[*protocol.FunctionCall]),
}
p.EnsureConfig(conf)
for i := 0; i < user.PoolSize; i++ {
......@@ -133,22 +135,38 @@ func (c *ConnectionPool) chooseServer(query string) *servers {
func (c *ConnectionPool) worker() {
for {
q := <-c.queries
srv := c.chooseServer(q.query)
if srv == nil {
log.Printf("call to query '%s' failed", q.query)
continue
select {
case q := <-c.queries:
srv := c.chooseServer(q.payload)
if srv == nil {
log.Printf("call to query '%s' failed", q.payload)
continue
}
// run the query
err := srv.primary.Query(q.payload, q.rep)
srv.mu.Unlock()
if err != nil {
log.Println(err)
}
close(q.rep)
case f := <-c.functionCalls:
srv := c.chooseServer("")
if srv == nil {
log.Printf("function call '%+v' failed", f.payload)
continue
}
// run the query
err := srv.primary.CallFunction(f.payload, f.rep)
srv.mu.Unlock()
if err != nil {
log.Println(err)
}
close(f.rep)
}
// run the query
err := srv.primary.Query(q.query, q.rep)
srv.mu.Unlock()
if err != nil {
log.Println(err)
}
close(q.rep)
}
}
......@@ -168,9 +186,20 @@ func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus {
func (c *ConnectionPool) Query(ctx context.Context, q string) (<-chan protocol.Packet, error) {
rep := make(chan protocol.Packet)
c.queries <- query{
query: q,
rep: rep,
c.queries <- request[string]{
payload: q,
rep: rep,
}
return rep, nil
}
func (c *ConnectionPool) CallFunction(ctx context.Context, f *protocol.FunctionCall) (<-chan protocol.Packet, error) {
rep := make(chan protocol.Packet)
c.functionCalls <- request[*protocol.FunctionCall]{
payload: f,
rep: rep,
}
return rep, nil
......
......@@ -194,6 +194,23 @@ func (s *Server) connect(ctx context.Context) error {
}
}
func (s *Server) forwardTo(rep chan<- protocol.Packet, predicate func(pkt protocol.Packet) (forward bool, finish bool)) error {
for {
var rsp protocol.Packet
rsp, err := protocol.ReadBackend(s.r)
if err != nil {
return err
}
forward, finish := predicate(rsp)
if forward {
rep <- rsp
}
if finish {
return nil
}
}
}
func (s *Server) Query(query string, rep chan<- protocol.Packet) error {
// send to server
q := new(protocol.Query)
......@@ -204,23 +221,34 @@ func (s *Server) Query(query string, rep chan<- protocol.Packet) error {
}
// read responses
for {
var rsp protocol.Packet
rsp, err = protocol.ReadBackend(s.r)
if err != nil {
return err
}
switch r := rsp.(type) {
return s.forwardTo(rep, func(pkt protocol.Packet) (forward bool, finish bool) {
switch r := pkt.(type) {
case *protocol.ReadyForQuery:
if r.Fields.Status == 'I' {
rep <- rsp
return nil
}
return true, r.Fields.Status == 'I'
case *protocol.CopyInResponse, *protocol.CopyOutResponse, *protocol.CopyBothResponse:
return fmt.Errorf("unsuported")
log.Println("client tried to enter copy mode")
return false, true
default:
return true, false
}
rep <- rsp
})
}
func (s *Server) CallFunction(payload *protocol.FunctionCall, rep chan<- protocol.Packet) error {
_, err := payload.Write(s.wr)
if err != nil {
return err
}
// read responses
return s.forwardTo(rep, func(pkt protocol.Packet) (forward bool, finish bool) {
switch r := pkt.(type) {
case *protocol.ReadyForQuery:
return true, r.Fields.Status == 'I'
default:
return true, false
}
})
}
func (s *Server) Close(ctx context.Context) error {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment