good morning!!!!

Skip to content
Snippets Groups Projects
Commit 60c108b1 authored by Garet Halliday's avatar Garet Halliday
Browse files

better query client

parent b4113311
Branches
Tags
No related merge requests found
......@@ -99,7 +99,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool {
var result authQueryResult
client := new(gsql.Client)
err := client.ExtendedQuery(&result, T.Config.PgBouncer.AuthQuery, user)
err := gsql.ExtendedQuery(client, &result, T.Config.PgBouncer.AuthQuery, user)
if err != nil {
log.Println("auth query failed:", err)
return nil
......
package gsql
import (
"crypto/tls"
"io"
"net"
"sync"
......@@ -9,100 +9,124 @@ import (
"pggat/lib/util/ring"
)
type batch struct {
result ResultWriter
packets []fed.Packet
}
type Client struct {
writeQ ring.Ring[ResultWriter]
writeC *sync.Cond
write ResultWriter
read ring.Ring[fed.Packet]
readQ ring.Ring[fed.Packet]
readC *sync.Cond
queue ring.Ring[batch]
closed bool
mu sync.Mutex
}
func (*Client) EnableSSLClient(_ *tls.Config) error {
panic("not implemented")
}
func (*Client) EnableSSLServer(_ *tls.Config) error {
panic("not implemented")
readQueue chan struct{}
writeQueue chan struct{}
}
func (*Client) ReadByte() (byte, error) {
panic("not implemented")
}
func (T *Client) Do(result ResultWriter, packets ...fed.Packet) {
T.mu.Lock()
defer T.mu.Unlock()
func (T *Client) queuePackets(packets ...fed.Packet) {
for _, packet := range packets {
T.readQ.PushBack(packet)
T.queue.PushBack(batch{
result: result,
packets: packets,
})
if T.readC != nil {
T.readC.Signal()
if T.readQueue != nil {
for {
select {
case T.readQueue <- struct{}{}:
default:
return
}
}
}
func (T *Client) queueResults(results ...ResultWriter) {
for _, result := range results {
T.writeQ.PushBack(result)
if T.writeC != nil {
T.writeC.Signal()
}
}
}
func (T *Client) ReadPacket(typed bool) (fed.Packet, error) {
T.mu.Lock()
defer T.mu.Unlock()
p, ok := T.readQ.PopFront()
for !ok {
var p fed.Packet
for {
var ok bool
p, ok = T.read.PopFront()
if ok {
break
}
// try to add next in queue
b, ok := T.queue.PopFront()
if ok {
for _, packet := range b.packets {
T.read.PushBack(packet)
}
T.write = b.result
outer:
for {
select {
case T.writeQueue <- struct{}{}:
default:
break outer
}
}
continue
}
if T.closed {
return nil, net.ErrClosed
return nil, io.EOF
}
if T.readC == nil {
T.readC = sync.NewCond(&T.mu)
func() {
if T.readQueue == nil {
T.readQueue = make(chan struct{})
}
T.readC.Wait()
p, ok = T.readQ.PopFront()
q := T.readQueue
T.mu.Unlock()
defer T.mu.Lock()
<-q
}()
}
if (p.Type() == 0 && typed) || (p.Type() != 0 && !typed) {
panic("tried to read typed as untyped or untyped as typed")
return nil, ErrTypedMismatch
}
return p, nil
}
func (*Client) WriteByte(_ byte) error {
panic("not implemented")
}
func (T *Client) WritePacket(packet fed.Packet) error {
if T.write == nil {
T.write, _ = T.writeQ.PopFront()
T.mu.Lock()
defer T.mu.Unlock()
for T.write == nil {
if T.closed {
return net.ErrClosed
return io.EOF
}
if T.writeC == nil {
T.writeC = sync.NewCond(&T.mu)
}
T.writeC.Wait()
T.write, _ = T.writeQ.PopFront()
func() {
if T.writeQueue == nil {
T.writeQueue = make(chan struct{})
}
q := T.writeQueue
T.mu.Unlock()
defer T.mu.Lock()
<-q
}()
}
if err := T.write.WritePacket(packet); err != nil {
return err
}
if T.write.Done() {
T.write = nil
}
return nil
}
......@@ -115,6 +139,13 @@ func (T *Client) Close() error {
}
T.closed = true
if T.writeQueue != nil {
close(T.writeQueue)
}
if T.readQueue != nil {
close(T.readQueue)
}
return nil
}
......
......@@ -8,20 +8,19 @@ import (
packets "pggat/lib/fed/packets/v3.0"
)
func (T *Client) ExtendedQuery(result any, query string, args ...any) error {
func ExtendedQuery(client *Client, result any, query string, args ...any) error {
if len(args) == 0 {
T.Query(query, result)
Query(client, []any{result}, query)
return nil
}
T.mu.Lock()
defer T.mu.Unlock()
var pkts []fed.Packet
// parse
parse := packets.Parse{
Query: query,
}
T.queuePackets(parse.IntoPacket())
pkts = append(pkts, parse.IntoPacket())
// bind
params := make([][]byte, 0, len(args))
......@@ -61,23 +60,23 @@ outer:
bind := packets.Bind{
ParameterValues: params,
}
T.queuePackets(bind.IntoPacket())
pkts = append(pkts, bind.IntoPacket())
// describe
describe := packets.Describe{
Which: 'P',
}
T.queuePackets(describe.IntoPacket())
pkts = append(pkts, describe.IntoPacket())
// execute
execute := packets.Execute{}
T.queuePackets(execute.IntoPacket())
pkts = append(pkts, execute.IntoPacket())
// sync
sync := fed.NewPacket(packets.TypeSync)
T.queuePackets(sync)
pkts = append(pkts, sync)
// result
T.queueResults(NewQueryWriter(result))
client.Do(NewQueryWriter(result), pkts...)
return nil
}
......@@ -7,4 +7,5 @@ var (
ErrExtraFields = errors.New("received unexpected fields")
ErrResultMustBeNonNil = errors.New("result must be non nil")
ErrUnexpectedType = errors.New("unexpected result type")
ErrTypedMismatch = errors.New("tried to read typed packet as untyped or untyped packet as typed")
)
......@@ -5,20 +5,15 @@ import (
packets "pggat/lib/fed/packets/v3.0"
)
func (T *Client) Query(query string, results ...any) {
T.mu.Lock()
defer T.mu.Unlock()
func Query(client *Client, results []any, query string) {
var q = packets.Query(query)
T.queueResults(NewQueryWriter(results...))
T.queuePackets(q.IntoPacket())
client.Do(NewQueryWriter(results...), q.IntoPacket())
}
type QueryWriter struct {
writers []RowWriter
writerNum int
done bool
}
func NewQueryWriter(results ...any) *QueryWriter {
......@@ -33,11 +28,6 @@ func NewQueryWriter(results ...any) *QueryWriter {
}
func (T *QueryWriter) WritePacket(packet fed.Packet) error {
if packet.Type() == packets.TypeReadyForQuery {
T.done = true
return nil
}
if T.writerNum >= len(T.writers) {
// ignore
return nil
......@@ -55,8 +45,4 @@ func (T *QueryWriter) WritePacket(packet fed.Packet) error {
return nil
}
func (T *QueryWriter) Done() bool {
return T.done
}
var _ ResultWriter = (*QueryWriter)(nil)
......@@ -39,7 +39,7 @@ func TestQuery(t *testing.T) {
var res Result
client := new(Client)
err = client.ExtendedQuery(&res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "bob")
err = ExtendedQuery(client, &res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "bob")
if err != nil {
t.Error(err)
return
......
......@@ -3,6 +3,5 @@ package gsql
import "pggat/lib/fed"
type ResultWriter interface {
WritePacket(packet fed.Packet) error
Done() bool
fed.Writer
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment