good morning!!!!

Skip to content
Snippets Groups Projects
Commit fc6f884a authored by Garet Halliday's avatar Garet Halliday
Browse files

a

parent 69b1d7ee
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import (
func Query(server zap.ReadWriter, query string) error {
packet := zap.NewPacket()
defer packet.Done()
packet.WriteType(packets.Query)
packet.WriteString(query)
err := server.Write(packet)
......
......@@ -12,6 +12,7 @@ import (
"pggat2/lib/util/chans"
"pggat2/lib/util/maps"
"pggat2/lib/util/ring"
"pggat2/lib/util/strings"
"pggat2/lib/zap"
packets "pggat2/lib/zap/packets/v3.0"
)
......@@ -68,11 +69,19 @@ func (T *Pool) _release(id uuid.UUID) {
T.ready.Signal()
}
func (T *Pool) close(conn Conn) {
_ = conn.rw.Close()
T.qmu.Lock()
defer T.qmu.Unlock()
delete(T.conns, conn.id)
}
func (T *Pool) release(conn Conn) {
// reset session state
err := backends.Query(conn.rw, "DISCARD ALL")
if err != nil {
_ = conn.rw.Close()
T.close(conn)
return
}
......@@ -82,47 +91,57 @@ func (T *Pool) release(conn Conn) {
}
func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, startupParameters map[string]string) {
defer func() {
_ = client.Close()
}()
connOk := true
conn := T.acquire(ctx)
defer func() {
if connOk {
T.release(conn)
} else {
T.close(conn)
}
}()
if func() bool {
pkts := zap.NewPackets()
defer pkts.Done()
for key, value := range conn.initialParameters {
if _, ok := startupParameters[key]; ok {
continue
}
packet := zap.NewPacket()
packets.WriteParameterStatus(packet, key, value)
pkts.Append(packet)
}
pkts := zap.NewPackets()
for key, value := range conn.initialParameters {
if _, ok := startupParameters[key]; ok {
continue
for key, value := range startupParameters {
err := backends.Query(conn.rw, "SET "+key+" = '"+strings.Escape(value, "'")+"'")
if err != nil {
connOk = false
return true
}
packet := zap.NewPacket()
packets.WriteParameterStatus(packet, key, value)
pkts.Append(packet)
}
packet := zap.NewPacket()
packets.WriteParameterStatus(packet, key, value)
pkts.Append(packet)
}
err := client.WriteV(pkts)
if err != nil {
pkts.Done()
_ = client.Close()
T.release(conn)
return
}
pkts.Done()
for key, value := range startupParameters {
err = backends.Query(conn.rw, "SET "+key+" = '"+value+"'")
err := client.WriteV(pkts)
if err != nil {
_ = client.Close()
_ = conn.rw.Close()
return
return true
}
return false
}() {
return
}
for {
clientErr, serverErr := bouncers.Bounce(client, conn.rw)
if clientErr != nil || serverErr != nil {
_ = client.Close()
if serverErr == nil {
T.release(conn)
} else {
_ = conn.rw.Close()
T.qmu.Lock()
delete(T.conns, conn.id)
T.qmu.Unlock()
}
connOk = serverErr == nil
break
}
}
......
package strings
import (
"strings"
"unicode/utf8"
)
func Escape(str, sequence string) string {
var b strings.Builder
for len(str) > 0 {
if strings.HasPrefix(str, sequence) {
b.WriteByte('\\')
b.WriteString(sequence)
str = str[len(sequence):]
continue
}
if strings.HasPrefix(str, "\\") {
b.WriteString("\\\\")
str = str[1:]
continue
}
r, size := utf8.DecodeRuneInString(str)
if r == utf8.RuneError {
return ""
}
b.WriteRune(r)
str = str[size:]
}
return b.String()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment