good morning!!!!

Skip to content
Snippets Groups Projects
Commit 78cc2f41 authored by Garet Halliday's avatar Garet Halliday
Browse files

i hope the programming gods can forgive me after that one

parent abe6bfcb
Branches
Tags
No related merge requests found
...@@ -22,13 +22,14 @@ import ( ...@@ -22,13 +22,14 @@ import (
"net" "net"
"reflect" "reflect"
"strings" "strings"
"sync"
) )
// / client state, one per client // / client state, one per client
type Client struct { type Client struct {
conn net.Conn conn net.Conn
r *bufio.Reader r *bufio.Reader
wr io.Writer wr *bufio.Writer
recv chan protocol.Packet recv chan protocol.Packet
...@@ -62,6 +63,8 @@ type Client struct { ...@@ -62,6 +63,8 @@ type Client struct {
state rune state rune
log zlog.Logger log zlog.Logger
mu sync.Mutex
} }
func NewClient( func NewClient(
...@@ -76,7 +79,7 @@ func NewClient( ...@@ -76,7 +79,7 @@ func NewClient(
c := &Client{ c := &Client{
conn: conn, conn: conn,
r: bufio.NewReader(conn), r: bufio.NewReader(conn),
wr: conn, wr: bufio.NewWriter(conn),
recv: make(chan protocol.Packet), recv: make(chan protocol.Packet),
addr: conn.RemoteAddr(), addr: conn.RemoteAddr(),
pid: int32(pid.Int64()), pid: int32(pid.Int64()),
...@@ -129,6 +132,10 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -129,6 +132,10 @@ func (c *Client) Accept(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
err = c.wr.Flush()
if err != nil {
return err
}
startup = new(protocol.StartupMessage) startup = new(protocol.StartupMessage)
err = startup.Read(c.r) err = startup.Read(c.r)
if err != nil { if err != nil {
...@@ -139,6 +146,10 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -139,6 +146,10 @@ func (c *Client) Accept(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
err = c.wr.Flush()
if err != nil {
return err
}
//TODO: we need to do an ssl handshake here. //TODO: we need to do an ssl handshake here.
var cert tls.Certificate var cert tls.Certificate
cert, err = tls.LoadX509KeyPair(c.conf.General.TlsCertificate, c.conf.General.TlsPrivateKey) cert, err = tls.LoadX509KeyPair(c.conf.General.TlsCertificate, c.conf.General.TlsPrivateKey)
...@@ -151,7 +162,7 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -151,7 +162,7 @@ func (c *Client) Accept(ctx context.Context) error {
} }
c.conn = tls.Server(c.conn, cfg) c.conn = tls.Server(c.conn, cfg)
c.r = bufio.NewReader(c.conn) c.r = bufio.NewReader(c.conn)
c.wr = c.conn c.wr = bufio.NewWriter(c.conn)
err = startup.Read(c.r) err = startup.Read(c.r)
if err != nil { if err != nil {
return err return err
...@@ -200,7 +211,11 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -200,7 +211,11 @@ func (c *Client) Accept(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
_, err = pkt.Write(c.wr) err = c.Send(pkt)
if err != nil {
return err
}
err = c.Flush()
if err != nil { if err != nil {
return err return err
} }
...@@ -263,7 +278,7 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -263,7 +278,7 @@ func (c *Client) Accept(ctx context.Context) error {
authOk := new(protocol.Authentication) authOk := new(protocol.Authentication)
authOk.Fields.Code = 0 authOk.Fields.Code = 0
_, err = authOk.Write(c.wr) err = c.Send(authOk)
if err != nil { if err != nil {
return err return err
} }
...@@ -271,7 +286,7 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -271,7 +286,7 @@ func (c *Client) Accept(ctx context.Context) error {
// //
info := c.server.GetServerInfo() info := c.server.GetServerInfo()
for _, inf := range info { for _, inf := range info {
_, err = inf.Write(c.wr) err = c.Send(inf)
if err != nil { if err != nil {
return err return err
} }
...@@ -279,19 +294,23 @@ func (c *Client) Accept(ctx context.Context) error { ...@@ -279,19 +294,23 @@ func (c *Client) Accept(ctx context.Context) error {
backendKeyData := new(protocol.BackendKeyData) backendKeyData := new(protocol.BackendKeyData)
backendKeyData.Fields.ProcessID = c.pid backendKeyData.Fields.ProcessID = c.pid
backendKeyData.Fields.SecretKey = c.secret_key backendKeyData.Fields.SecretKey = c.secret_key
_, err = backendKeyData.Write(c.wr) err = c.Send(backendKeyData)
if err != nil { if err != nil {
return err return err
} }
readyForQuery := new(protocol.ReadyForQuery) readyForQuery := new(protocol.ReadyForQuery)
readyForQuery.Fields.Status = byte('I') readyForQuery.Fields.Status = byte('I')
_, err = readyForQuery.Write(c.wr) err = c.Send(readyForQuery)
if err != nil { if err != nil {
return err return err
} }
go c.recvLoop() go c.recvLoop()
open := true open := true
for open { for open {
err = c.Flush()
if err != nil {
return err
}
open, err = c.tick(ctx) open, err = c.tick(ctx)
if !open { if !open {
break break
...@@ -348,7 +367,7 @@ func (c *Client) handle_cancel(ctx context.Context, p *protocol.StartupMessage) ...@@ -348,7 +367,7 @@ func (c *Client) handle_cancel(ctx context.Context, p *protocol.StartupMessage)
func (c *Client) tick(ctx context.Context) (bool, error) { func (c *Client) tick(ctx context.Context) (bool, error) {
var rsp protocol.Packet var rsp protocol.Packet
select { select {
case rsp = <-c.Recv(): case rsp = <-c.recv:
case <-ctx.Done(): case <-ctx.Done():
return false, ctx.Err() return false, ctx.Err()
} }
...@@ -492,11 +511,19 @@ func (c *Client) GetPortal(name string) *protocol.Bind { ...@@ -492,11 +511,19 @@ func (c *Client) GetPortal(name string) *protocol.Bind {
} }
func (c *Client) Send(pkt protocol.Packet) error { func (c *Client) Send(pkt protocol.Packet) error {
c.mu.Lock()
defer c.mu.Unlock()
//log.Printf("sent packet(%s) %+v", reflect.TypeOf(pkt), pkt) //log.Printf("sent packet(%s) %+v", reflect.TypeOf(pkt), pkt)
_, err := pkt.Write(c.wr) _, err := pkt.Write(c.wr)
return err return err
} }
func (c *Client) Flush() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.wr.Flush()
}
func (c *Client) Recv() <-chan protocol.Packet { func (c *Client) Recv() <-chan protocol.Packet {
return c.recv return c.recv
} }
......
...@@ -3,7 +3,6 @@ package server ...@@ -3,7 +3,6 @@ package server
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"io"
"net" "net"
"reflect" "reflect"
"time" "time"
...@@ -28,7 +27,7 @@ type Server struct { ...@@ -28,7 +27,7 @@ type Server struct {
remote net.Addr remote net.Addr
conn net.Conn conn net.Conn
r *bufio.Reader r *bufio.Reader
wr io.Writer wr *bufio.Writer
server_info []*protocol.ParameterStatus server_info []*protocol.ParameterStatus
...@@ -72,13 +71,9 @@ func Dial(ctx context.Context, ...@@ -72,13 +71,9 @@ func Dial(ctx context.Context,
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = s.conn.(*net.TCPConn).SetNoDelay(false)
if err != nil {
return nil, err
}
s.remote = s.conn.RemoteAddr() s.remote = s.conn.RemoteAddr()
s.r = bufio.NewReader(s.conn) s.r = bufio.NewReader(s.conn)
s.wr = s.conn s.wr = bufio.NewWriter(s.conn)
s.user = *user s.user = *user
s.db = db s.db = db
...@@ -127,7 +122,7 @@ func (s *Server) startup(ctx context.Context) error { ...@@ -127,7 +122,7 @@ func (s *Server) startup(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
return nil return s.flush()
} }
func (s *Server) connect(ctx context.Context) error { func (s *Server) connect(ctx context.Context) error {
...@@ -164,17 +159,20 @@ func (s *Server) connect(ctx context.Context) error { ...@@ -164,17 +159,20 @@ func (s *Server) connect(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
func() {
rsp := new(protocol.AuthenticationResponse) rsp := new(protocol.AuthenticationResponse)
buf := bufpool.Get(len(scrm.Name()) + 1 + 4 + len(bts)) buf := bufpool.Get(len(scrm.Name()) + 1 + 4 + len(bts))
buf.Reset() buf.Reset()
defer bufpool.Put(buf)
_, _ = protocol.WriteString(buf, scrm.Name()) _, _ = protocol.WriteString(buf, scrm.Name())
_, _ = protocol.WriteInt32(buf, int32(len(bts))) _, _ = protocol.WriteInt32(buf, int32(len(bts)))
buf.Write(bts) buf.Write(bts)
rsp.Fields.Data = buf.Bytes() rsp.Fields.Data = buf.Bytes()
err = s.writePacket(rsp) err = s.writePacket(rsp)
}() bufpool.Put(buf)
if err != nil {
return err
}
err = s.flush()
if err != nil { if err != nil {
return err return err
} }
...@@ -189,6 +187,10 @@ func (s *Server) connect(ctx context.Context) error { ...@@ -189,6 +187,10 @@ func (s *Server) connect(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
err = s.flush()
if err != nil {
return err
}
case 12: // SASL_FINAL case 12: // SASL_FINAL
s.log.Debug().Str("method", "scram256").Msg("sasl final") s.log.Debug().Str("method", "scram256").Msg("sasl final")
var done bool var done bool
...@@ -244,6 +246,10 @@ func (s *Server) writePacket(pkt protocol.Packet) error { ...@@ -244,6 +246,10 @@ func (s *Server) writePacket(pkt protocol.Packet) error {
return err return err
} }
func (s *Server) flush() error {
return s.wr.Flush()
}
func (s *Server) readPacket() (protocol.Packet, error) { func (s *Server) readPacket() (protocol.Packet, error) {
return protocol.ReadBackend(s.r) return protocol.ReadBackend(s.r)
} }
...@@ -311,6 +317,7 @@ func (s *Server) destructPreparedStatement(name string) { ...@@ -311,6 +317,7 @@ func (s *Server) destructPreparedStatement(name string) {
query := new(protocol.Query) query := new(protocol.Query)
query.Fields.Query = fmt.Sprintf("DEALLOCATE \"%s\"", name) query.Fields.Query = fmt.Sprintf("DEALLOCATE \"%s\"", name)
_ = s.writePacket(query) _ = s.writePacket(query)
_ = s.flush()
// await server ready // await server ready
for { for {
r, _ := s.readPacket() r, _ := s.readPacket()
...@@ -358,6 +365,10 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error { ...@@ -358,6 +365,10 @@ func (s *Server) Describe(client gat.Client, d *protocol.Describe) error {
if err != nil { if err != nil {
return err return err
} }
err = s.flush()
if err != nil {
return err
}
return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
//log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt)
...@@ -386,6 +397,10 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error { ...@@ -386,6 +397,10 @@ func (s *Server) Execute(client gat.Client, e *protocol.Execute) error {
if err != nil { if err != nil {
return err return err
} }
err = s.flush()
if err != nil {
return err
}
return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
//log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt) //log.Println("forward packet(%s) %+v", reflect.TypeOf(pkt), pkt)
...@@ -408,6 +423,10 @@ func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query strin ...@@ -408,6 +423,10 @@ func (s *Server) SimpleQuery(ctx context.Context, client gat.Client, query strin
if err != nil { if err != nil {
return err return err
} }
err = s.flush()
if err != nil {
return err
}
// this function seems wild but it has to be the way it is so we read the whole response, even if the // this function seems wild but it has to be the way it is so we read the whole response, even if the
// client fails midway // client fails midway
...@@ -435,6 +454,10 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin ...@@ -435,6 +454,10 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin
if err != nil { if err != nil {
return err return err
} }
err = s.flush()
if err != nil {
return err
}
return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
//log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt) //log.Printf("got server pkt pkt(%s): %+v ", reflect.TypeOf(pkt), pkt)
switch p := pkt.(type) { switch p := pkt.(type) {
...@@ -444,6 +467,8 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin ...@@ -444,6 +467,8 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin
// send to client and wait for next query // send to client and wait for next query
err = client.Send(pkt) err = client.Send(pkt)
if err == nil {
err = client.Flush()
if err == nil { if err == nil {
select { select {
case r := <-client.Recv(): case r := <-client.Recv():
...@@ -452,6 +477,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin ...@@ -452,6 +477,7 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin
case *protocol.Query: case *protocol.Query:
//forward to server //forward to server
_ = s.writePacket(r) _ = s.writePacket(r)
_ = s.flush()
default: default:
err = fmt.Errorf("expected an error in transaction state but got something else") err = fmt.Errorf("expected an error in transaction state but got something else")
} }
...@@ -459,11 +485,13 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin ...@@ -459,11 +485,13 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin
err = ctx.Err() err = ctx.Err()
} }
} }
}
if err != nil { if err != nil {
end := new(protocol.Query) end := new(protocol.Query)
end.Fields.Query = "END;" end.Fields.Query = "END;"
_ = s.writePacket(end) _ = s.writePacket(end)
_ = s.flush()
} }
} else { } else {
finish = true finish = true
...@@ -479,6 +507,10 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin ...@@ -479,6 +507,10 @@ func (s *Server) Transaction(ctx context.Context, client gat.Client, query strin
} }
func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { func (s *Server) CopyIn(ctx context.Context, client gat.Client) error {
err := client.Flush()
if err != nil {
return err
}
for { for {
// detect a disconneted /hanging client by waiting 30 seoncds, else timeout // detect a disconneted /hanging client by waiting 30 seoncds, else timeout
// otherwise, just keep reading packets until a done or error is received // otherwise, just keep reading packets until a done or error is received
...@@ -489,10 +521,11 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { ...@@ -489,10 +521,11 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error {
case pkt = <-client.Recv(): case pkt = <-client.Recv():
case <-cctx.Done(): case <-cctx.Done():
_ = s.writePacket(new(protocol.CopyFail)) _ = s.writePacket(new(protocol.CopyFail))
_ = s.flush()
return cctx.Err() return cctx.Err()
} }
cancel() cancel()
err := s.writePacket(pkt) err = s.writePacket(pkt)
if err != nil { if err != nil {
return err return err
} }
...@@ -500,7 +533,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error { ...@@ -500,7 +533,7 @@ func (s *Server) CopyIn(ctx context.Context, client gat.Client) error {
switch pkt.(type) { switch pkt.(type) {
case *protocol.CopyDone, *protocol.CopyFail: case *protocol.CopyDone, *protocol.CopyFail:
// don't error on copyfail because the client is the one that errored, it already knows // don't error on copyfail because the client is the one that errored, it already knows
return nil return s.flush()
} }
} }
} }
...@@ -510,6 +543,10 @@ func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall) ...@@ -510,6 +543,10 @@ func (s *Server) CallFunction(client gat.Client, payload *protocol.FunctionCall)
if err != nil { if err != nil {
return err return err
} }
err = s.flush()
if err != nil {
return err
}
// read responses // read responses
return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) { return s.forwardTo(client, func(pkt protocol.Packet) (forward bool, finish bool, err error) {
switch pkt.(type) { switch pkt.(type) {
......
...@@ -134,11 +134,6 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { ...@@ -134,11 +134,6 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error {
// TODO: TLS // TODO: TLS
func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error {
err := c.(*net.TCPConn).SetNoDelay(false)
if err != nil {
return err
}
cl := client.NewClient(g, g.c, c, false) cl := client.NewClient(g, g.c, c, false)
func() { func() {
...@@ -152,10 +147,11 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error { ...@@ -152,10 +147,11 @@ func (g *Gatling) handleConnection(ctx context.Context, c net.Conn) error {
delete(g.clients, cl.Id()) delete(g.clients, cl.Id())
}() }()
err = cl.Accept(ctx) err := cl.Accept(ctx)
if err != nil { if err != nil {
log.Println("err in connection:", err.Error()) log.Println("err in connection:", err.Error())
_ = cl.Send(pg_error.IntoPacket(err)) _ = cl.Send(pg_error.IntoPacket(err))
_ = cl.Flush()
} }
_ = c.Close() _ = c.Close()
return nil return nil
......
...@@ -19,6 +19,7 @@ type Client interface { ...@@ -19,6 +19,7 @@ type Client interface {
SetCurrentConn(conn Connection) SetCurrentConn(conn Connection)
Send(pkt protocol.Packet) error Send(pkt protocol.Packet) error
Flush() error
Recv() <-chan protocol.Packet Recv() <-chan protocol.Packet
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment