diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 2bee445e4e0b5766430bbbd7ded2624e0476d49b..d7638941534aab13703224b98efe79354cd55a61 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -6,15 +6,15 @@ import ( _ "net/http/pprof" "pggat2/lib/backend/backends/v0" + "pggat2/lib/frontend" "pggat2/lib/frontend/frontends/v0" - "pggat2/lib/pnet" "pggat2/lib/rob" "pggat2/lib/rob/schedulers/v2" ) type job struct { - rw pnet.ReadWriter - done chan<- struct{} + client frontend.Client + done chan<- struct{} } func testServer(r rob.Scheduler) { @@ -30,7 +30,7 @@ func testServer(r rob.Scheduler) { sink := r.NewSink(0) for { j := sink.Read().(job) - server.Handle(j.rw) + server.Handle(j.client) select { case j.done <- struct{}{}: default: @@ -62,16 +62,13 @@ func main() { done := make(chan struct{}) defer close(done) for { - reader, err := pnet.PreRead(client) + err := client.Wait() if err != nil { break } source.Schedule(job{ - rw: pnet.JoinedReadWriter{ - Reader: reader, - Writer: client, - }, - done: done, + client: client, + done: done, }, 0) <-done } diff --git a/lib/backend/backends/v0/server.go b/lib/backend/backends/v0/server.go index badea41e277f04dee388a8d58b4f795067d937d2..a3b33f414191331d2b5741a5b1d4d32628594b17 100644 --- a/lib/backend/backends/v0/server.go +++ b/lib/backend/backends/v0/server.go @@ -9,6 +9,8 @@ import ( "pggat2/lib/auth/md5" "pggat2/lib/auth/sasl" "pggat2/lib/backend" + "pggat2/lib/eqp" + "pggat2/lib/frontend" "pggat2/lib/perror" "pggat2/lib/pnet" "pggat2/lib/pnet/packet" @@ -24,19 +26,23 @@ var ErrServerFailed = perror.New( type Server struct { conn net.Conn - pnet.IOReader - pnet.IOWriter + reader pnet.IOReader + writer pnet.IOWriter - cancellationKey [8]byte - parameters map[string]string + cancellationKey [8]byte + parameters map[string]string + preparedStatements map[string]eqp.PreparedStatement + portals map[string]eqp.Portal } func NewServer(conn net.Conn) *Server { server := &Server{ - conn: conn, - IOReader: pnet.MakeIOReader(conn), - IOWriter: pnet.MakeIOWriter(conn), - parameters: make(map[string]string), + conn: conn, + reader: pnet.MakeIOReader(conn), + writer: pnet.MakeIOWriter(conn), + parameters: make(map[string]string), + preparedStatements: make(map[string]eqp.PreparedStatement), + portals: make(map[string]eqp.Portal), } err := server.accept() if err != nil { @@ -296,7 +302,7 @@ func (T *Server) fail() { debug.PrintStack() } -func (T *Server) copyIn0(peer pnet.ReadWriter) (bool, perror.Error) { +func (T *Server) copyIn0(peer frontend.Client) (bool, perror.Error) { in, err := peer.Read() if err != nil { return false, perror.Wrap(err) @@ -322,7 +328,7 @@ func (T *Server) copyIn0(peer pnet.ReadWriter) (bool, perror.Error) { } } -func (T *Server) copyIn(peer pnet.ReadWriter, in packet.In) perror.Error { +func (T *Server) copyIn(peer frontend.Client, in packet.In) perror.Error { // send in (copyInResponse) to client err := pnet.ProxyPacket(peer, in) if err != nil { @@ -343,7 +349,7 @@ func (T *Server) copyIn(peer pnet.ReadWriter, in packet.In) perror.Error { return nil } -func (T *Server) copyOut0(peer pnet.ReadWriter) (bool, perror.Error) { +func (T *Server) copyOut0(peer frontend.Client) (bool, perror.Error) { in, err := T.Read() if err != nil { T.fail() @@ -363,7 +369,7 @@ func (T *Server) copyOut0(peer pnet.ReadWriter) (bool, perror.Error) { } } -func (T *Server) copyOut(peer pnet.ReadWriter, in packet.In) perror.Error { +func (T *Server) copyOut(peer frontend.Client, in packet.In) perror.Error { // send in (copyOutResponse) to server err := pnet.ProxyPacket(T, in) if err != nil { @@ -385,7 +391,7 @@ func (T *Server) copyOut(peer pnet.ReadWriter, in packet.In) perror.Error { return nil } -func (T *Server) query0(peer pnet.ReadWriter) (bool, perror.Error) { +func (T *Server) query0(peer frontend.Client) (bool, perror.Error) { in, err := T.Read() if err != nil { T.fail() @@ -421,7 +427,7 @@ func (T *Server) query0(peer pnet.ReadWriter) (bool, perror.Error) { } } -func (T *Server) query(peer pnet.ReadWriter, in packet.In) perror.Error { +func (T *Server) query(peer frontend.Client, in packet.In) perror.Error { // send in (initial query) to server err := pnet.ProxyPacket(T, in) if err != nil { @@ -442,7 +448,7 @@ func (T *Server) query(peer pnet.ReadWriter, in packet.In) perror.Error { return nil } -func (T *Server) functionCall0(peer pnet.ReadWriter) (bool, perror.Error) { +func (T *Server) functionCall0(peer frontend.Client) (bool, perror.Error) { in, err := T.Read() if err != nil { T.fail() @@ -462,7 +468,7 @@ func (T *Server) functionCall0(peer pnet.ReadWriter) (bool, perror.Error) { } } -func (T *Server) functionCall(peer pnet.ReadWriter, in packet.In) perror.Error { +func (T *Server) functionCall(peer frontend.Client, in packet.In) perror.Error { // send in (FunctionCall) to server err := pnet.ProxyPacket(T, in) if err != nil { @@ -483,7 +489,7 @@ func (T *Server) functionCall(peer pnet.ReadWriter, in packet.In) perror.Error { return nil } -func (T *Server) handle(peer pnet.ReadWriter) perror.Error { +func (T *Server) handle(peer frontend.Client) perror.Error { in, err := peer.Read() if err != nil { return perror.Wrap(err) @@ -493,6 +499,12 @@ func (T *Server) handle(peer pnet.ReadWriter) perror.Error { return T.query(peer, in) case packet.FunctionCall: return T.functionCall(peer, in) + case packet.Sync: + // TODO(garet) send ready for query + return nil + case packet.Flush: + // nothing really to do + return nil default: return perror.New( perror.FATAL, @@ -503,7 +515,7 @@ func (T *Server) handle(peer pnet.ReadWriter) perror.Error { } // Handle handles a transaction from peer, returning when the transaction is complete -func (T *Server) Handle(peer pnet.ReadWriter) { +func (T *Server) Handle(peer frontend.Client) { err := T.handle(peer) if err != nil { out := peer.Write() @@ -512,4 +524,21 @@ func (T *Server) Handle(peer pnet.ReadWriter) { } } +func (T *Server) Write() packet.Out { + T.writer.WriteFunc(T.write) + return T.writer.Write() +} + +func (T *Server) write(typ packet.Type, payload []byte) error { + return T.writer.WriteRaw(typ, payload) +} + +func (T *Server) Read() (packet.In, error) { + return T.reader.Read() +} + +func (T *Server) ReadUntyped() (packet.In, error) { + return T.reader.ReadUntyped() +} + var _ backend.Server = (*Server)(nil) diff --git a/lib/frontend/frontends/v0/portal.go b/lib/eqp/portal.go similarity index 88% rename from lib/frontend/frontends/v0/portal.go rename to lib/eqp/portal.go index 76e1ce07c6568ead294b369d6944b7782ee6d720..c55ff00f78c6f4de44814428cc0c49ae61b7759b 100644 --- a/lib/frontend/frontends/v0/portal.go +++ b/lib/eqp/portal.go @@ -1,4 +1,4 @@ -package frontends +package eqp type Portal struct { Source string diff --git a/lib/frontend/frontends/v0/preparedStatement.go b/lib/eqp/preparedStatement.go similarity index 83% rename from lib/frontend/frontends/v0/preparedStatement.go rename to lib/eqp/preparedStatement.go index 94f86bf2944cee0aca6b841e33a049ce7716beec..5f62b2433b6efb56130bd60aeb638ba7bfee73b3 100644 --- a/lib/frontend/frontends/v0/preparedStatement.go +++ b/lib/eqp/preparedStatement.go @@ -1,4 +1,4 @@ -package frontends +package eqp type PreparedStatement struct { Query string diff --git a/lib/frontend/client.go b/lib/frontend/client.go index 563b14a325f50d78bb6ce02ee3325aad727a3b1b..f6ac36f4888781626521550bf4fa06720508236f 100644 --- a/lib/frontend/client.go +++ b/lib/frontend/client.go @@ -1,7 +1,12 @@ package frontend -import "pggat2/lib/pnet" +import ( + "pggat2/lib/eqp" + "pggat2/lib/pnet" +) type Client interface { pnet.ReadWriter + GetPortal(string) (eqp.Portal, bool) + GetPreparedStatement(string) (eqp.PreparedStatement, bool) } diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go index 0e350c7d21f093a22ec7fcf524d531b0667a4d39..c22fd1f0c4831bf0b8da6651dfa58c5f40f1008b 100644 --- a/lib/frontend/frontends/v0/client.go +++ b/lib/frontend/frontends/v0/client.go @@ -9,6 +9,7 @@ import ( "pggat2/lib/auth/md5" "pggat2/lib/auth/sasl" + "pggat2/lib/eqp" "pggat2/lib/frontend" "pggat2/lib/perror" "pggat2/lib/pnet" @@ -22,8 +23,8 @@ type Client struct { conn net.Conn - pnet.IOReader - pnet.IOWriter + reader pnet.IOReader + writer pnet.IOWriter user string database string @@ -31,18 +32,18 @@ type Client struct { // cancellation key data cancellationKey [8]byte parameters map[string]string - preparedStatements map[string]PreparedStatement - portals map[string]Portal + preparedStatements map[string]eqp.PreparedStatement + portals map[string]eqp.Portal } func NewClient(conn net.Conn) *Client { client := &Client{ conn: conn, - IOReader: pnet.MakeIOReader(conn), - IOWriter: pnet.MakeIOWriter(conn), + reader: pnet.MakeIOReader(conn), + writer: pnet.MakeIOWriter(conn), parameters: make(map[string]string), - preparedStatements: make(map[string]PreparedStatement), - portals: make(map[string]Portal), + preparedStatements: make(map[string]eqp.PreparedStatement), + portals: make(map[string]eqp.Portal), } err := client.accept() if err != nil { @@ -51,6 +52,16 @@ func NewClient(conn net.Conn) *Client { return client } +func (T *Client) GetPreparedStatement(name string) (eqp.PreparedStatement, bool) { + v, ok := T.preparedStatements[name] + return v, ok +} + +func (T *Client) GetPortal(name string) (eqp.Portal, bool) { + v, ok := T.portals[name] + return v, ok +} + func (T *Client) startup0() (bool, perror.Error) { pkt, err := T.ReadUntyped() if err != nil { @@ -78,11 +89,11 @@ func (T *Client) startup0() (bool, perror.Error) { ) case 5679: // SSL is not supported yet - err = T.WriteByte('N') + err = T.writer.WriteByte('N') return false, perror.Wrap(err) case 5680: // GSSAPI is not supported yet - err = T.WriteByte('N') + err = T.writer.WriteByte('N') return false, perror.Wrap(err) default: return false, perror.New( @@ -441,9 +452,33 @@ func (T *Client) accept() perror.Error { return nil } +func (T *Client) Wait() error { + _, err := T.conn.Read(nil) + return err +} + +func (T *Client) Write() packet.Out { + T.writer.WriteFunc(T.write) + return T.writer.Write() +} + +func (T *Client) write(typ packet.Type, payload []byte) error { + inBuf := packet.MakeInBuf(typ, payload) + in := packet.MakeIn(&inBuf) + switch in.Type() { + case packet.ParameterStatus: + parameter, value, ok := packets.ReadParameterStatus(in) + if !ok { + return errors.New("bad packet format") + } + T.parameters[parameter] = value + } + return T.writer.WriteRaw(typ, payload) +} + func (T *Client) Read() (packet.In, error) { for { - in, err := T.IOReader.Read() + in, err := T.reader.Read() if err != nil { return packet.In{}, err } @@ -461,7 +496,7 @@ func (T *Client) Read() (packet.In, error) { if !ok { return packet.In{}, errors.New("bad packet format") } - T.preparedStatements[destination] = PreparedStatement{ + T.preparedStatements[destination] = eqp.PreparedStatement{ Query: query, ParameterDataTypes: parameterDataTypes, } @@ -470,7 +505,7 @@ func (T *Client) Read() (packet.In, error) { if !ok { return packet.In{}, errors.New("bad packet format") } - T.portals[destination] = Portal{ + T.portals[destination] = eqp.Portal{ Source: source, ParameterFormatCodes: parameterFormatCodes, ParameterValues: parameterValues, @@ -495,6 +530,10 @@ func (T *Client) Read() (packet.In, error) { } } +func (T *Client) ReadUntyped() (packet.In, error) { + return T.reader.ReadUntyped() +} + func (T *Client) Close(err perror.Error) { if err != nil { out := T.Write() diff --git a/lib/pnet/bufreader.go b/lib/pnet/bufreader.go deleted file mode 100644 index ac60af2e956e7feb7ce9a5cc22e5fabd1ebbdbd8..0000000000000000000000000000000000000000 --- a/lib/pnet/bufreader.go +++ /dev/null @@ -1,86 +0,0 @@ -package pnet - -import ( - "pggat2/lib/pnet/packet" - "pggat2/lib/util/decorator" - "pggat2/lib/util/ring" -) - -type bufIn struct { - typ packet.Type - start int - length int -} - -type BufReader struct { - noCopy decorator.NoCopy - buf packet.InBuf - payloads []byte - ins ring.Ring[bufIn] - reader Reader -} - -func MakeBufReader(reader Reader) BufReader { - return BufReader{ - reader: reader, - } -} - -func NewBufReader(reader Reader) *BufReader { - v := MakeBufReader(reader) - return &v -} - -func (T *BufReader) Buffer(in packet.In) { - if T.ins.Length() == 0 { - // reset header - T.payloads = T.payloads[:0] - } - start := len(T.payloads) - full := in.Full() - length := len(full) - T.payloads = append(T.payloads, full...) - T.ins.PushBack(bufIn{ - typ: in.Type(), - start: start, - length: length, - }) -} - -func (T *BufReader) Read() (packet.In, error) { - if in, ok := T.ins.PopFront(); ok { - if in.typ == packet.None { - panic("expected typed packet, got untyped") - } - T.buf.Reset( - in.typ, - T.payloads[in.start:in.start+in.length], - ) - // returned buffered packet - return packet.MakeIn( - &T.buf, - ), nil - } - // fall back to underlying - return T.reader.Read() -} - -func (T *BufReader) ReadUntyped() (packet.In, error) { - if in, ok := T.ins.PopFront(); ok { - if in.typ != packet.None { - panic("expected untyped packet, got typed") - } - T.buf.Reset( - packet.None, - T.payloads[in.start:in.start+in.length], - ) - // returned buffered packet - return packet.MakeIn( - &T.buf, - ), nil - } - // fall back to underlying - return T.reader.ReadUntyped() -} - -var _ Reader = (*BufReader)(nil) diff --git a/lib/pnet/iowriter.go b/lib/pnet/iowriter.go index c1ff75c72ce8f609ecc9353dfd959b2d4ace740f..d6c2a65849ea154cf98e840174ce541bc04a4d0e 100644 --- a/lib/pnet/iowriter.go +++ b/lib/pnet/iowriter.go @@ -33,7 +33,7 @@ func NewIOWriter(writer io.Writer) *IOWriter { // Calling Write will invalidate all other packet.Out's for this IOWriter func (T *IOWriter) Write() packet.Out { if !T.buf.Initialized() { - T.buf.Initialize(T.write) + T.buf.Initialize(T.WriteRaw) } T.buf.Reset() @@ -42,7 +42,11 @@ func (T *IOWriter) Write() packet.Out { ) } -func (T *IOWriter) write(typ packet.Type, payload []byte) error { +func (T *IOWriter) WriteFunc(f func(packet.Type, []byte) error) { + T.buf.Initialize(f) +} + +func (T *IOWriter) WriteRaw(typ packet.Type, payload []byte) error { /* if typ != packet.None { log.Printf("write typed packet %c %v\n", typ, payload) } else { diff --git a/lib/pnet/packet/in.go b/lib/pnet/packet/in.go index 7d4094b404e09b734b95dce49439095a9d6fe275..ae846d0ed0140d6fa8ef7824b694086be0c097d6 100644 --- a/lib/pnet/packet/in.go +++ b/lib/pnet/packet/in.go @@ -15,6 +15,13 @@ type InBuf struct { rev int } +func MakeInBuf(typ Type, buf []byte) InBuf { + return InBuf{ + typ: typ, + buf: buf, + } +} + func (T *InBuf) Reset( typ Type, buf []byte, diff --git a/lib/pnet/preread.go b/lib/pnet/preread.go deleted file mode 100644 index c8a9963d208843135babe4ab819a5a20026828b9..0000000000000000000000000000000000000000 --- a/lib/pnet/preread.go +++ /dev/null @@ -1,55 +0,0 @@ -package pnet - -import ( - "pggat2/lib/pnet/packet" -) - -// PreRead returns a buffered reader containing the first packet -// useful for waiting for a full packet before actually doing work -func PreRead(reader Reader) (Reader, error) { - in, err := reader.Read() - if err != nil { - return nil, err - } - return newPolled(in, reader), nil -} - -// PreReadUntyped does the same thing as PreReadUntyped but uses Reader.ReadUntyped -func PreReadUntyped(reader Reader) (Reader, error) { - in, err := reader.ReadUntyped() - if err != nil { - return nil, err - } - return newPolled(in, reader), nil -} - -type preRead struct { - in packet.In - read bool - reader Reader -} - -func newPolled(in packet.In, reader Reader) *preRead { - return &preRead{ - in: in, - reader: reader, - } -} - -func (T *preRead) Read() (packet.In, error) { - if !T.read { - T.read = true - return T.in, nil - } - return T.reader.Read() -} - -func (T *preRead) ReadUntyped() (packet.In, error) { - if !T.read { - T.read = true - return T.in, nil - } - return T.reader.ReadUntyped() -} - -var _ Reader = (*preRead)(nil)