diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index a2199353e02e642481697e32d83f0d5249010158..fc26af1fb8b2db176684fcc25556ebc7aaf36a29 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -1,25 +1,28 @@ package main -import "pggat2/lib/frontend/frontends/v0" +import ( + "net" + + "pggat2/lib/backend/backends/v0" +) func main() { - frontend, err := frontends.NewFrontend() + /*frontend, err := frontends.NewFrontend() if err != nil { panic(err) } err = frontend.Run() + if err != nil { + panic(err) + }*/ + conn, err := net.Dial("tcp", "localhost:5432") + if err != nil { + panic(err) + } + server, err := backends.NewServer(conn) if err != nil { panic(err) } - /* - conn, err := net.Dial("tcp", "localhost:5432") - if err != nil { - panic(err) - } - server, err := backends.NewServer(conn) - if err != nil { - panic(err) - } - _ = server - _ = conn.Close()*/ + _ = server + _ = conn.Close() } diff --git a/lib/backend/backend.go b/lib/backend/backend.go index 7d3570acd2b9589b63812f227fb14377bfe60a1c..555b2c3d8295fe19b238d9dc7422c13ed3e70718 100644 --- a/lib/backend/backend.go +++ b/lib/backend/backend.go @@ -1,6 +1,9 @@ package backend +import "pggat2/lib/pnet" + type Server interface { + pnet.ReadWriter } type Backend interface { diff --git a/lib/backend/backends/v0/server.go b/lib/backend/backends/v0/server.go index 9b86f6ef630939c475e45ebf5c8e45db4c67387e..70d0e0477850086ba8534930e68601e3adec862a 100644 --- a/lib/backend/backends/v0/server.go +++ b/lib/backend/backends/v0/server.go @@ -17,8 +17,8 @@ var ErrProtocolError = errors.New("server sent unexpected packet") type Server struct { conn net.Conn - pnet.Reader - pnet.Writer + pnet.IOReader + pnet.IOWriter cancellationKey [8]byte parameters map[string]string @@ -27,8 +27,8 @@ type Server struct { func NewServer(conn net.Conn) (*Server, error) { server := &Server{ conn: conn, - Reader: pnet.MakeReader(conn), - Writer: pnet.MakeWriter(conn), + IOReader: pnet.MakeIOReader(conn), + IOWriter: pnet.MakeIOWriter(conn), parameters: make(map[string]string), } err := server.accept() @@ -43,7 +43,6 @@ func (T *Server) authenticationSASLChallenge(mechanism sasl.Client) (bool, error if err != nil { return false, err } - defer in.Done() if in.Type() != packet.Authentication { return false, ErrProtocolError @@ -66,7 +65,7 @@ func (T *Server) authenticationSASLChallenge(mechanism sasl.Client) (bool, error out.Type(packet.AuthenticationResponse) out.Bytes(response) - err = out.Done() + err = out.Send() return false, err case 12: // finish @@ -97,7 +96,7 @@ func (T *Server) authenticationSASL(mechanisms []string, username, password stri out.Int32(int32(len(initialResponse))) out.Bytes(initialResponse) } - err = out.Done() + err = out.Send() if err != nil { return err } @@ -120,14 +119,14 @@ func (T *Server) authenticationMD5(salt [4]byte, username, password string) erro out := T.Write() out.Type(packet.AuthenticationResponse) out.String(md5.Encode(username, password, salt)) - return out.Done() + return out.Send() } func (T *Server) authenticationCleartext(password string) error { out := T.Write() out.Type(packet.AuthenticationResponse) out.String(password) - return out.Done() + return out.Send() } func (T *Server) startup0(username, password string) (bool, error) { @@ -138,25 +137,20 @@ func (T *Server) startup0(username, password string) (bool, error) { switch in.Type() { case packet.ErrorResponse: - in.Done() return false, errors.New("received error response") case packet.Authentication: method, ok := in.Int32() if !ok { - in.Done() return false, ErrBadPacketFormat } // they have more authentication methods than there are pokemon switch method { case 0: // we're good to go, that was easy - in.Done() return true, nil case 2: - in.Done() return false, errors.New("kerberos v5 is not supported") case 3: - in.Done() return false, T.authenticationCleartext(password) case 5: var salt [4]byte @@ -164,16 +158,12 @@ func (T *Server) startup0(username, password string) (bool, error) { if !ok { return false, ErrBadPacketFormat } - in.Done() return false, T.authenticationMD5(salt, username, password) case 6: - in.Done() return false, errors.New("scm credential is not supported") case 7: - in.Done() return false, errors.New("gss is not supported") case 9: - in.Done() return false, errors.New("sspi is not supported") case 10: // read list of mechanisms @@ -189,18 +179,14 @@ func (T *Server) startup0(username, password string) (bool, error) { mechanisms = append(mechanisms, mechanism) } - in.Done() return false, T.authenticationSASL(mechanisms, username, password) default: - in.Done() return false, errors.New("unknown authentication method") } case packet.NegotiateProtocolVersion: // we only support protocol 3.0 for now - in.Done() return false, errors.New("server wanted to negotiate protocol version") default: - in.Done() return false, ErrProtocolError } } @@ -210,7 +196,6 @@ func (T *Server) startup1() (bool, error) { if err != nil { return false, err } - defer in.Done() switch in.Type() { case packet.BackendKeyData: @@ -252,7 +237,7 @@ func (T *Server) accept() error { out.String("postgres") out.String("") - err := out.Done() + err := out.Send() if err != nil { return err } @@ -283,4 +268,34 @@ func (T *Server) accept() error { return nil } +func (T *Server) proxy(in *packet.In) error { + out := T.Write() + out.Type(in.Type()) + out.Bytes(in.Full()) + return out.Send() +} + +func (T *Server) query(peer pnet.ReadWriter) error { + return nil +} + +// Transaction handles a transaction from peer, returning when the transaction is complete +func (T *Server) Transaction(peer pnet.ReadWriter) error { + in, err := peer.Read() + if err != nil { + return err + } + switch in.Type() { + case packet.Query: + // proxy to backend + err = T.proxy(&in) + if err != nil { + return err + } + return T.query(peer) + default: + return errors.New("unsupported operation") + } +} + var _ backend.Server = (*Server)(nil) diff --git a/lib/frontend/frontend.go b/lib/frontend/frontend.go index 71829b23718f6f0a223d4704bf0bfec4277654d6..cb7d3827d5ca0eba3b8be9a5377da004d621250b 100644 --- a/lib/frontend/frontend.go +++ b/lib/frontend/frontend.go @@ -1,6 +1,9 @@ package frontend +import "pggat2/lib/pnet" + type Client interface { + pnet.ReadWriter } type Frontend interface { diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go index 0a260d4a9c65fcfcbccbac3a29c7c209819aaa04..6f44fe8b8d8445b27a3538ee46de087c05753999 100644 --- a/lib/frontend/frontends/v0/client.go +++ b/lib/frontend/frontends/v0/client.go @@ -28,8 +28,8 @@ var ErrProtocolError = perror.New( type Client struct { conn net.Conn - pnet.Reader - pnet.Writer + pnet.IOReader + pnet.IOWriter user string database string @@ -42,8 +42,8 @@ type Client struct { func NewClient(conn net.Conn) *Client { client := &Client{ conn: conn, - Reader: pnet.MakeReader(conn), - Writer: pnet.MakeWriter(conn), + IOReader: pnet.MakeIOReader(conn), + IOWriter: pnet.MakeIOWriter(conn), parameters: make(map[string]string), } err := client.accept() @@ -68,7 +68,6 @@ func (T *Client) startup0() (bool, perror.Error) { if err != nil { return false, perror.WrapError(err) } - defer pkt.Done() majorVersion, ok := pkt.Uint16() if !ok { @@ -162,7 +161,7 @@ func (T *Client) startup0() (bool, perror.Error) { out := T.Write() negotiateProtocolVersionPacket(&out, unsupportedOptions) - err = out.Done() + err = out.Send() if err != nil { return false, perror.WrapError(err) } @@ -209,7 +208,6 @@ func (T *Client) authenticationSASLInitial(username, password string) (sasl.Serv if err != nil { return nil, nil, false, perror.WrapError(err) } - defer in.Done() if in.Type() != packet.AuthenticationResponse { return nil, nil, false, ErrBadPacketFormat } @@ -239,7 +237,6 @@ func (T *Client) authenticationSASLContinue(tool sasl.Server) ([]byte, bool, per if err != nil { return nil, false, perror.WrapError(err) } - defer in.Done() if in.Type() != packet.AuthenticationResponse { return nil, false, ErrProtocolError } @@ -254,7 +251,7 @@ func (T *Client) authenticationSASLContinue(tool sasl.Server) ([]byte, bool, per func (T *Client) authenticationSASL(username, password string) perror.Error { out := T.Write() authenticationSASLPacket(&out) - err := out.Done() + err := out.Send() if err != nil { return perror.WrapError(err) } @@ -268,7 +265,7 @@ func (T *Client) authenticationSASL(username, password string) perror.Error { if done { out = T.Write() authenticationSASLFinalPacket(&out, resp) - err = out.Done() + err = out.Send() if err != nil { return perror.WrapError(err) } @@ -276,7 +273,7 @@ func (T *Client) authenticationSASL(username, password string) perror.Error { } else { out = T.Write() authenticationSASLContinuePacket(&out, resp) - err = out.Done() + err = out.Send() if err != nil { return perror.WrapError(err) } @@ -305,7 +302,7 @@ func (T *Client) authenticationMD5(username, password string) perror.Error { out := T.Write() authenticationMD5Packet(&out, salt) - err = out.Done() + err = out.Send() if err != nil { return perror.WrapError(err) } @@ -315,7 +312,6 @@ func (T *Client) authenticationMD5(username, password string) perror.Error { if err != nil { return perror.WrapError(err) } - defer in.Done() if in.Type() != packet.AuthenticationResponse { return perror.New( @@ -350,7 +346,7 @@ func (T *Client) authenticationCleartext(password string) perror.Error { out := T.Write() authenticationCleartextPacket(&out) - err := out.Done() + err := out.Send() if err != nil { return perror.WrapError(err) } @@ -360,7 +356,6 @@ func (T *Client) authenticationCleartext(password string) perror.Error { if err != nil { return perror.WrapError(err) } - defer in.Done() if in.Type() != packet.AuthenticationResponse { return perror.New( @@ -422,7 +417,7 @@ func (T *Client) accept() perror.Error { out := T.Write() authenticationOkPacket(&out) - err := out.Done() + err := out.Send() if err != nil { return perror.WrapError(err) } @@ -435,7 +430,7 @@ func (T *Client) accept() perror.Error { out = T.Write() backendKeyDataPacket(&out, T.cancellationKey) - err = out.Done() + err = out.Send() if err != nil { return perror.WrapError(err) } @@ -444,7 +439,7 @@ func (T *Client) accept() perror.Error { out = T.Write() readyForQueryPacket(&out, 'I') - err = out.Done() + err = out.Send() if err != nil { return perror.WrapError(err) } @@ -473,7 +468,7 @@ func (T *Client) Close(err perror.Error) { out.Uint8(0) - _ = out.Done() + _ = out.Send() } _ = T.conn.Close() } diff --git a/lib/pnet/bufreader.go b/lib/pnet/bufreader.go new file mode 100644 index 0000000000000000000000000000000000000000..1dc6a92d227682461a8b24744eb8e8c82b63b27b --- /dev/null +++ b/lib/pnet/bufreader.go @@ -0,0 +1,84 @@ +package pnet + +import ( + "pggat2/lib/pnet/packet" + "pggat2/lib/util/ring" +) + +type bufIn struct { + typ packet.Type + start int + length int +} + +type BufReader struct { + buf packet.InBuf + payloads []byte + ins ring.Ring[bufIn] + reader Reader +} + +func MakeBufReader(reader Reader) BufReader { + return BufReader{ + reader: reader, + } +} + +func NewBufReader(reader Reader) *BufReader { + v := MakeBufReader(reader) + return &v +} + +func (T *BufReader) Buffer(in *packet.In) { + if T.ins.Length() == 0 { + // reset header + T.payloads = T.payloads[:0] + } + start := len(T.payloads) + full := in.Full() + length := len(full) + T.payloads = append(T.payloads, full...) + T.ins.PushBack(bufIn{ + typ: in.Type(), + start: start, + length: length, + }) +} + +func (T *BufReader) Read() (packet.In, error) { + if in, ok := T.ins.PopFront(); ok { + if in.typ == packet.None { + panic("expected typed packet, got untyped") + } + T.buf.Reset( + in.typ, + T.payloads[in.start:in.start+in.length], + ) + // returned buffered packet + return packet.MakeIn( + &T.buf, + ), nil + } + // fall back to underlying + return T.reader.Read() +} + +func (T *BufReader) ReadUntyped() (packet.In, error) { + if in, ok := T.ins.PopFront(); ok { + if in.typ != packet.None { + panic("expected untyped packet, got typed") + } + T.buf.Reset( + packet.None, + T.payloads[in.start:in.start+in.length], + ) + // returned buffered packet + return packet.MakeIn( + &T.buf, + ), nil + } + // fall back to underlying + return T.reader.ReadUntyped() +} + +var _ Reader = (*BufReader)(nil) diff --git a/lib/pnet/ioreader.go b/lib/pnet/ioreader.go new file mode 100644 index 0000000000000000000000000000000000000000..e8e947c262ae1522ef6bdd2d6a232296fae347bc --- /dev/null +++ b/lib/pnet/ioreader.go @@ -0,0 +1,101 @@ +package pnet + +import ( + "encoding/binary" + "io" + + "pggat2/lib/pnet/packet" + "pggat2/lib/util/slices" +) + +type IOReader struct { + reader io.Reader + // header buffer for reading packet headers + // (allocating within Read would escape to heap) + header [4]byte + + buf packet.InBuf + payload []byte +} + +func MakeIOReader(reader io.Reader) IOReader { + return IOReader{ + reader: reader, + payload: make([]byte, 1024), + } +} + +func NewIOReader(reader io.Reader) *IOReader { + v := MakeIOReader(reader) + return &v +} + +// Read fetches the next packet from the underlying io.Reader and gives you a packet.In +// Calling Read will invalidate all other packet.In's for this IOReader +func (T *IOReader) Read() (packet.In, error) { + typ, err := T.ReadByte() + if err != nil { + return packet.In{}, err + } + + err = T.readPayload() + if err != nil { + return packet.In{}, err + } + + T.buf.Reset( + packet.Type(typ), + T.payload, + ) + + return packet.MakeIn( + &T.buf, + ), nil +} + +// ReadUntyped is similar to Read, but it doesn't read a packet.Type +func (T *IOReader) ReadUntyped() (packet.In, error) { + err := T.readPayload() + if err != nil { + return packet.In{}, err + } + + T.buf.Reset( + packet.None, + T.payload, + ) + + return packet.MakeIn( + &T.buf, + ), nil +} + +func (T *IOReader) readPayload() error { + if T.payload == nil { + panic("Previous Read was never finished") + } + + // read length int32 + _, err := io.ReadFull(T.reader, T.header[:]) + if err != nil { + return err + } + + length := binary.BigEndian.Uint32(T.header[:]) - 4 + + // resize body to length + T.payload = slices.Resize(T.payload, int(length)) + // read body + _, err = io.ReadFull(T.reader, T.payload) + if err != nil { + return err + } + + return nil +} + +func (T *IOReader) ReadByte() (byte, error) { + T.header[0] = 0 + _, err := io.ReadFull(T.reader, T.header[:1]) + return T.header[0], err +} diff --git a/lib/pnet/iowriter.go b/lib/pnet/iowriter.go new file mode 100644 index 0000000000000000000000000000000000000000..6f4a8e91ac7eb5247cc8d74bdc22ef50fd426244 --- /dev/null +++ b/lib/pnet/iowriter.go @@ -0,0 +1,72 @@ +package pnet + +import ( + "encoding/binary" + "io" + + "pggat2/lib/pnet/packet" +) + +type IOWriter struct { + writer io.Writer + // header buffer for writing packet headers + // (allocating within Write would escape to heap) + header [4]byte + + buf packet.OutBuf +} + +func MakeIOWriter(writer io.Writer) IOWriter { + return IOWriter{ + writer: writer, + } +} + +func NewIOWriter(writer io.Writer) *IOWriter { + v := MakeIOWriter(writer) + return &v +} + +// Write gives you a packet.Out +// Calling Write will invalidate all other packet.Out's for this IOWriter +func (T *IOWriter) Write() packet.Out { + if !T.buf.Initialized() { + T.buf.Initialize(T.write) + } + T.buf.Reset() + + return packet.MakeOut( + &T.buf, + ) +} + +func (T *IOWriter) write(typ packet.Type, payload []byte) error { + // write type byte (if present) + if typ != packet.None { + err := T.WriteByte(byte(typ)) + if err != nil { + return err + } + } + + // write len+4 + binary.BigEndian.PutUint32(T.header[:], uint32(len(payload)+4)) + _, err := T.writer.Write(T.header[:]) + if err != nil { + return err + } + + // write payload + _, err = T.writer.Write(payload) + if err != nil { + return err + } + + return nil +} + +func (T *IOWriter) WriteByte(b byte) error { + T.header[0] = b + _, err := T.writer.Write(T.header[:1]) + return err +} diff --git a/lib/pnet/packet/in.go b/lib/pnet/packet/in.go index 9ba0758f93a1193d993eedcc9ec670dda847637f..282bdf7a8c734710007d5bcb19ba0611a34c466b 100644 --- a/lib/pnet/packet/in.go +++ b/lib/pnet/packet/in.go @@ -5,143 +5,152 @@ import ( "math" ) -type In struct { - noCopy noCopy - typ Type - buf []byte - pos int - done bool - finish func([]byte) +type InBuf struct { + typ Type + buf []byte + pos int + rev int } -func MakeIn( +func (T *InBuf) Reset( typ Type, buf []byte, - finish func([]byte), +) { + T.typ = typ + T.buf = buf + T.pos = 0 + T.rev++ +} + +type In struct { + buf *InBuf + rev int +} + +func MakeIn( + buf *InBuf, ) In { return In{ - typ: typ, - buf: buf, - finish: finish, + buf: buf, + rev: buf.rev, } } -func (T *In) Type() Type { - return T.typ +func (T In) done() bool { + return T.rev != T.buf.rev +} + +func (T In) Type() Type { + if T.done() { + panic("Read after Send") + } + return T.buf.typ } // Full returns the full payload of the packet. // NOTE: Full will be invalid after Done is called -func (T *In) Full() []byte { - if T.done { - panic("Read after Done") +func (T In) Full() []byte { + if T.done() { + panic("Read after Send") } - return T.buf + return T.buf.buf } // Remaining returns the remaining payload of the packet. // NOTE: Remaining will be invalid after Done is called -func (T *In) Remaining() []byte { +func (T In) Remaining() []byte { full := T.Full() - return full[T.pos:] + return full[T.buf.pos:] } -func (T *In) Int8() (int8, bool) { +func (T In) Int8() (int8, bool) { v, ok := T.Uint8() return int8(v), ok } -func (T *In) Int16() (int16, bool) { +func (T In) Int16() (int16, bool) { v, ok := T.Uint16() return int16(v), ok } -func (T *In) Int32() (int32, bool) { +func (T In) Int32() (int32, bool) { v, ok := T.Uint32() return int32(v), ok } -func (T *In) Int64() (int64, bool) { +func (T In) Int64() (int64, bool) { v, ok := T.Uint64() return int64(v), ok } -func (T *In) Uint8() (uint8, bool) { +func (T In) Uint8() (uint8, bool) { rem := T.Remaining() if len(rem) < 1 { return 0, false } v := rem[0] - T.pos += 1 + T.buf.pos += 1 return v, true } -func (T *In) Uint16() (uint16, bool) { +func (T In) Uint16() (uint16, bool) { rem := T.Remaining() if len(rem) < 2 { return 0, false } v := binary.BigEndian.Uint16(rem) - T.pos += 2 + T.buf.pos += 2 return v, true } -func (T *In) Uint32() (uint32, bool) { +func (T In) Uint32() (uint32, bool) { rem := T.Remaining() if len(rem) < 4 { return 0, false } v := binary.BigEndian.Uint32(rem) - T.pos += 4 + T.buf.pos += 4 return v, true } -func (T *In) Uint64() (uint64, bool) { +func (T In) Uint64() (uint64, bool) { rem := T.Remaining() if len(rem) < 8 { return 0, false } v := binary.BigEndian.Uint64(rem) - T.pos += 8 + T.buf.pos += 8 return v, true } -func (T *In) Float32() (float32, bool) { +func (T In) Float32() (float32, bool) { v, ok := T.Uint32() return math.Float32frombits(v), ok } -func (T *In) Float64() (float64, bool) { +func (T In) Float64() (float64, bool) { v, ok := T.Uint64() return math.Float64frombits(v), ok } -func (T *In) String() (string, bool) { +func (T In) String() (string, bool) { rem := T.Remaining() for i, c := range rem { if c == 0 { v := string(rem[:i]) - T.pos += i + 1 + T.buf.pos += i + 1 return v, true } } return "", false } -func (T *In) Bytes(b []byte) bool { +func (T In) Bytes(b []byte) bool { rem := T.Remaining() if len(b) > len(rem) { return false } copy(b, rem) - T.pos += len(b) + T.buf.pos += len(b) return true } - -func (T *In) Done() { - if T.done { - panic("Done called twice") - } - T.done = true - T.finish(T.buf) -} diff --git a/lib/pnet/packet/out.go b/lib/pnet/packet/out.go index f1545cb91bc00d871a28cd761206a03e280047f6..9a617297d93a90c12942e700c8e96ecc19decc3e 100644 --- a/lib/pnet/packet/out.go +++ b/lib/pnet/packet/out.go @@ -5,99 +5,123 @@ import ( "math" ) -type Out struct { - noCopy noCopy +type OutBuf struct { typ Type buf []byte - done bool + rev int finish func(Type, []byte) error } +func (T *OutBuf) Initialized() bool { + return T.finish != nil +} + +func (T *OutBuf) Initialize(finish func(Type, []byte) error) { + T.finish = finish +} + +func (T *OutBuf) Reset() { + T.typ = None + T.buf = T.buf[:0] + T.rev++ +} + +type Out struct { + buf *OutBuf + rev int +} + func MakeOut( - buf []byte, - finish func(Type, []byte) error, + buf *OutBuf, ) Out { return Out{ - buf: buf, - finish: finish, + buf: buf, + rev: buf.rev, } } -func (T *Out) Type(typ Type) { - T.typ = typ +func (T Out) done() bool { + return T.rev != T.buf.rev +} + +func (T Out) Type(typ Type) { + if T.done() { + panic("Write after Send") + } + T.buf.typ = typ } -func (T *Out) Int8(v int8) { +func (T Out) Int8(v int8) { T.Uint8(uint8(v)) } -func (T *Out) Int16(v int16) { +func (T Out) Int16(v int16) { T.Uint16(uint16(v)) } -func (T *Out) Int32(v int32) { +func (T Out) Int32(v int32) { T.Uint32(uint32(v)) } -func (T *Out) Int64(v int64) { +func (T Out) Int64(v int64) { T.Uint64(uint64(v)) } -func (T *Out) Uint8(v uint8) { - if T.done { - panic("Write after Done") +func (T Out) Uint8(v uint8) { + if T.done() { + panic("Write after Send") } - T.buf = append(T.buf, v) + T.buf.buf = append(T.buf.buf, v) } -func (T *Out) Uint16(v uint16) { - if T.done { - panic("Write after Done") +func (T Out) Uint16(v uint16) { + if T.done() { + panic("Write after Send") } - T.buf = binary.BigEndian.AppendUint16(T.buf, v) + T.buf.buf = binary.BigEndian.AppendUint16(T.buf.buf, v) } -func (T *Out) Uint32(v uint32) { - if T.done { - panic("Write after Done") +func (T Out) Uint32(v uint32) { + if T.done() { + panic("Write after Send") } - T.buf = binary.BigEndian.AppendUint32(T.buf, v) + T.buf.buf = binary.BigEndian.AppendUint32(T.buf.buf, v) } -func (T *Out) Uint64(v uint64) { - if T.done { - panic("Write after Done") +func (T Out) Uint64(v uint64) { + if T.done() { + panic("Write after Send") } - T.buf = binary.BigEndian.AppendUint64(T.buf, v) + T.buf.buf = binary.BigEndian.AppendUint64(T.buf.buf, v) } -func (T *Out) Float32(v float32) { +func (T Out) Float32(v float32) { T.Uint32(math.Float32bits(v)) } -func (T *Out) Float64(v float64) { +func (T Out) Float64(v float64) { T.Uint64(math.Float64bits(v)) } -func (T *Out) String(v string) { - if T.done { - panic("Write after Done") +func (T Out) String(v string) { + if T.done() { + panic("Write after Send") } - T.buf = append(T.buf, v...) + T.buf.buf = append(T.buf.buf, v...) T.Uint8(0) } -func (T *Out) Bytes(v []byte) { - if T.done { - panic("Write after Done") +func (T Out) Bytes(v []byte) { + if T.done() { + panic("Write after Send") } - T.buf = append(T.buf, v...) + T.buf.buf = append(T.buf.buf, v...) } -func (T *Out) Done() error { - if T.done { - panic("Done called twice") +func (T Out) Send() error { + if T.done() { + panic("Send called twice") } - T.done = true - return T.finish(T.typ, T.buf) + T.buf.rev++ + return T.buf.finish(T.buf.typ, T.buf.buf) } diff --git a/lib/pnet/reader.go b/lib/pnet/reader.go index ccd55b8e1a2064c644bde0f49b41c5754e438961..259cc33bcdddb9880ab354cba63b5d463c7d3624 100644 --- a/lib/pnet/reader.go +++ b/lib/pnet/reader.go @@ -1,99 +1,8 @@ package pnet -import ( - "encoding/binary" - "io" +import "pggat2/lib/pnet/packet" - "pggat2/lib/pnet/packet" - "pggat2/lib/util/slices" -) - -type Reader struct { - reader io.Reader - // buffer for reading packet headers - // (allocating within Read would escape to heap) - buffer [4]byte - payload []byte -} - -func MakeReader(reader io.Reader) Reader { - return Reader{ - reader: reader, - payload: make([]byte, 1024), - } -} - -func NewReader(reader io.Reader) *Reader { - v := MakeReader(reader) - return &v -} - -func (T *Reader) Read() (packet.In, error) { - typ, err := T.ReadByte() - if err != nil { - return packet.In{}, err - } - - err = T.readPayload() - if err != nil { - return packet.In{}, err - } - - payload := T.payload - T.payload = nil - - return packet.MakeIn( - packet.Type(typ), - payload, - func(payload []byte) { - T.payload = payload - }, - ), nil -} - -func (T *Reader) ReadUntyped() (packet.In, error) { - err := T.readPayload() - if err != nil { - return packet.In{}, err - } - - payload := T.payload - T.payload = nil - return packet.MakeIn( - packet.None, - payload, - func(bytes []byte) { - T.payload = payload - }, - ), nil -} - -func (T *Reader) readPayload() error { - if T.payload == nil { - panic("Previous Read was never finished") - } - - // read length int32 - _, err := io.ReadFull(T.reader, T.buffer[:]) - if err != nil { - return err - } - - length := binary.BigEndian.Uint32(T.buffer[:]) - 4 - - // resize body to length - T.payload = slices.Resize(T.payload, int(length)) - // read body - _, err = io.ReadFull(T.reader, T.payload) - if err != nil { - return err - } - - return nil -} - -func (T *Reader) ReadByte() (byte, error) { - T.buffer[0] = 0 - _, err := io.ReadFull(T.reader, T.buffer[:1]) - return T.buffer[0], err +type Reader interface { + Read() (packet.In, error) + ReadUntyped() (packet.In, error) } diff --git a/lib/pnet/readwriter.go b/lib/pnet/readwriter.go new file mode 100644 index 0000000000000000000000000000000000000000..81d125660a867ab52c11808e17355925925ed053 --- /dev/null +++ b/lib/pnet/readwriter.go @@ -0,0 +1,6 @@ +package pnet + +type ReadWriter interface { + Reader + Writer +} diff --git a/lib/pnet/writer.go b/lib/pnet/writer.go index 6f392c4c08a1a5b7e294c56acb40094bfcea7e2a..c5a48e90fcabc77390a9e58d1d20f4cf82f24411 100644 --- a/lib/pnet/writer.go +++ b/lib/pnet/writer.go @@ -1,74 +1,7 @@ package pnet -import ( - "encoding/binary" - "io" +import "pggat2/lib/pnet/packet" - "pggat2/lib/pnet/packet" -) - -type Writer struct { - writer io.Writer - // buffer for writing packet headers - // (allocating within Write would escape to heap) - buffer [4]byte - payload []byte -} - -func MakeWriter(writer io.Writer) Writer { - return Writer{ - writer: writer, - payload: make([]byte, 1024), - } -} - -func NewWriter(writer io.Writer) *Writer { - v := MakeWriter(writer) - return &v -} - -func (T *Writer) Write() packet.Out { - if T.payload == nil { - panic("Previous Write was never finished") - } - - payload := T.payload - T.payload = nil - return packet.MakeOut( - payload[:0], - T.write, - ) -} - -func (T *Writer) write(typ packet.Type, payload []byte) error { - T.payload = payload - - // write type byte (if present) - if typ != packet.None { - err := T.WriteByte(byte(typ)) - if err != nil { - return err - } - } - - // write len+4 - binary.BigEndian.PutUint32(T.buffer[:], uint32(len(payload)+4)) - _, err := T.writer.Write(T.buffer[:]) - if err != nil { - return err - } - - // write payload - _, err = T.writer.Write(payload) - if err != nil { - return err - } - - return nil -} - -func (T *Writer) WriteByte(b byte) error { - T.buffer[0] = b - _, err := T.writer.Write(T.buffer[:1]) - return err +type Writer interface { + Write() packet.Out }