diff --git a/lib/gat/modes/pgbouncer/pools.go b/lib/gat/modes/pgbouncer/pools.go index 52e06d0778bef864348d8b0926a52bb2b83265eb..c6d0585d6d731d74d285534aa321b35390e2cf97 100644 --- a/lib/gat/modes/pgbouncer/pools.go +++ b/lib/gat/modes/pgbouncer/pools.go @@ -99,7 +99,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { var result authQueryResult client := new(gsql.Client) - err := client.ExtendedQuery(&result, T.Config.PgBouncer.AuthQuery, user) + err := gsql.ExtendedQuery(client, &result, T.Config.PgBouncer.AuthQuery, user) if err != nil { log.Println("auth query failed:", err) return nil diff --git a/lib/gsql/client.go b/lib/gsql/client.go index 0e672ba86eb84b10dd89195a7a5d609420aea545..07df6795345361bcbaa5db6b9823253a13ba1213 100644 --- a/lib/gsql/client.go +++ b/lib/gsql/client.go @@ -1,7 +1,7 @@ package gsql import ( - "crypto/tls" + "io" "net" "sync" @@ -9,46 +9,40 @@ import ( "pggat/lib/util/ring" ) +type batch struct { + result ResultWriter + packets []fed.Packet +} + type Client struct { - writeQ ring.Ring[ResultWriter] - writeC *sync.Cond - write ResultWriter + write ResultWriter + read ring.Ring[fed.Packet] - readQ ring.Ring[fed.Packet] - readC *sync.Cond + queue ring.Ring[batch] closed bool mu sync.Mutex -} -func (*Client) EnableSSLClient(_ *tls.Config) error { - panic("not implemented") -} - -func (*Client) EnableSSLServer(_ *tls.Config) error { - panic("not implemented") -} - -func (*Client) ReadByte() (byte, error) { - panic("not implemented") -} - -func (T *Client) queuePackets(packets ...fed.Packet) { - for _, packet := range packets { - T.readQ.PushBack(packet) - - if T.readC != nil { - T.readC.Signal() - } - } + readQueue chan struct{} + writeQueue chan struct{} } -func (T *Client) queueResults(results ...ResultWriter) { - for _, result := range results { - T.writeQ.PushBack(result) +func (T *Client) Do(result ResultWriter, packets ...fed.Packet) { + T.mu.Lock() + defer T.mu.Unlock() - if T.writeC != nil { - T.writeC.Signal() + T.queue.PushBack(batch{ + result: result, + packets: packets, + }) + + if T.readQueue != nil { + for { + select { + case T.readQueue <- struct{}{}: + default: + return + } } } } @@ -57,52 +51,82 @@ func (T *Client) ReadPacket(typed bool) (fed.Packet, error) { T.mu.Lock() defer T.mu.Unlock() - p, ok := T.readQ.PopFront() - for !ok { - if T.closed { - return nil, net.ErrClosed + var p fed.Packet + for { + var ok bool + p, ok = T.read.PopFront() + if ok { + break } - if T.readC == nil { - T.readC = sync.NewCond(&T.mu) + + // try to add next in queue + b, ok := T.queue.PopFront() + if ok { + for _, packet := range b.packets { + T.read.PushBack(packet) + } + T.write = b.result + outer: + for { + select { + case T.writeQueue <- struct{}{}: + default: + break outer + } + } + continue + } + + if T.closed { + return nil, io.EOF } - T.readC.Wait() - p, ok = T.readQ.PopFront() + + func() { + if T.readQueue == nil { + T.readQueue = make(chan struct{}) + } + q := T.readQueue + + T.mu.Unlock() + defer T.mu.Lock() + + <-q + }() } if (p.Type() == 0 && typed) || (p.Type() != 0 && !typed) { - panic("tried to read typed as untyped or untyped as typed") + return nil, ErrTypedMismatch } return p, nil } -func (*Client) WriteByte(_ byte) error { - panic("not implemented") -} - func (T *Client) WritePacket(packet fed.Packet) error { - if T.write == nil { - T.write, _ = T.writeQ.PopFront() - for T.write == nil { - if T.closed { - return net.ErrClosed - } - if T.writeC == nil { - T.writeC = sync.NewCond(&T.mu) - } - T.writeC.Wait() - T.write, _ = T.writeQ.PopFront() + T.mu.Lock() + defer T.mu.Unlock() + + for T.write == nil { + if T.closed { + return io.EOF } + + func() { + if T.writeQueue == nil { + T.writeQueue = make(chan struct{}) + } + q := T.writeQueue + + T.mu.Unlock() + defer T.mu.Lock() + + <-q + }() } if err := T.write.WritePacket(packet); err != nil { return err } - if T.write.Done() { - T.write = nil - } - return nil } @@ -115,6 +139,13 @@ func (T *Client) Close() error { } T.closed = true + + if T.writeQueue != nil { + close(T.writeQueue) + } + if T.readQueue != nil { + close(T.readQueue) + } return nil } diff --git a/lib/gsql/eq.go b/lib/gsql/eq.go index 46100e074ca7c79040e4d6184fc3cd518931ac20..f03cfe1d92d5c10f821d0f5457e16c38d4b5f1c7 100644 --- a/lib/gsql/eq.go +++ b/lib/gsql/eq.go @@ -8,20 +8,19 @@ import ( packets "pggat/lib/fed/packets/v3.0" ) -func (T *Client) ExtendedQuery(result any, query string, args ...any) error { +func ExtendedQuery(client *Client, result any, query string, args ...any) error { if len(args) == 0 { - T.Query(query, result) + Query(client, []any{result}, query) return nil } - T.mu.Lock() - defer T.mu.Unlock() + var pkts []fed.Packet // parse parse := packets.Parse{ Query: query, } - T.queuePackets(parse.IntoPacket()) + pkts = append(pkts, parse.IntoPacket()) // bind params := make([][]byte, 0, len(args)) @@ -61,23 +60,23 @@ outer: bind := packets.Bind{ ParameterValues: params, } - T.queuePackets(bind.IntoPacket()) + pkts = append(pkts, bind.IntoPacket()) // describe describe := packets.Describe{ Which: 'P', } - T.queuePackets(describe.IntoPacket()) + pkts = append(pkts, describe.IntoPacket()) // execute execute := packets.Execute{} - T.queuePackets(execute.IntoPacket()) + pkts = append(pkts, execute.IntoPacket()) // sync sync := fed.NewPacket(packets.TypeSync) - T.queuePackets(sync) + pkts = append(pkts, sync) // result - T.queueResults(NewQueryWriter(result)) + client.Do(NewQueryWriter(result), pkts...) return nil } diff --git a/lib/gsql/errors.go b/lib/gsql/errors.go index 93ad0d836db2d94ab39ca91ab9b9375ee79762d0..a550f7496561030d181eedbcbc8ba281007e2d61 100644 --- a/lib/gsql/errors.go +++ b/lib/gsql/errors.go @@ -7,4 +7,5 @@ var ( ErrExtraFields = errors.New("received unexpected fields") ErrResultMustBeNonNil = errors.New("result must be non nil") ErrUnexpectedType = errors.New("unexpected result type") + ErrTypedMismatch = errors.New("tried to read typed packet as untyped or untyped packet as typed") ) diff --git a/lib/gsql/query.go b/lib/gsql/query.go index a2eb1e1d8ae3e4dbe2fcf5d3d0821f0d9bbbf23b..518b55869ae67508d3b72dd60160cf3dc0e996c6 100644 --- a/lib/gsql/query.go +++ b/lib/gsql/query.go @@ -5,20 +5,15 @@ import ( packets "pggat/lib/fed/packets/v3.0" ) -func (T *Client) Query(query string, results ...any) { - T.mu.Lock() - defer T.mu.Unlock() - +func Query(client *Client, results []any, query string) { var q = packets.Query(query) - T.queueResults(NewQueryWriter(results...)) - T.queuePackets(q.IntoPacket()) + client.Do(NewQueryWriter(results...), q.IntoPacket()) } type QueryWriter struct { writers []RowWriter writerNum int - done bool } func NewQueryWriter(results ...any) *QueryWriter { @@ -33,11 +28,6 @@ func NewQueryWriter(results ...any) *QueryWriter { } func (T *QueryWriter) WritePacket(packet fed.Packet) error { - if packet.Type() == packets.TypeReadyForQuery { - T.done = true - return nil - } - if T.writerNum >= len(T.writers) { // ignore return nil @@ -55,8 +45,4 @@ func (T *QueryWriter) WritePacket(packet fed.Packet) error { return nil } -func (T *QueryWriter) Done() bool { - return T.done -} - var _ ResultWriter = (*QueryWriter)(nil) diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index f77c56986577f9bfacba1ce24b1e11003502f143..9de3ed4d0e795c572edd089905d14e29fdd7752a 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -39,7 +39,7 @@ func TestQuery(t *testing.T) { var res Result client := new(Client) - err = client.ExtendedQuery(&res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "bob") + err = ExtendedQuery(client, &res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "bob") if err != nil { t.Error(err) return diff --git a/lib/gsql/result.go b/lib/gsql/result.go index ccb7c040b6483b391c2f4be70aecb44def59875d..e3f9f79d6bebddc1b4a39543f6f34049c3519ed2 100644 --- a/lib/gsql/result.go +++ b/lib/gsql/result.go @@ -3,6 +3,5 @@ package gsql import "pggat/lib/fed" type ResultWriter interface { - WritePacket(packet fed.Packet) error - Done() bool + fed.Writer }