diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index c99a0a7113c5a64535807d3f87f4bc1755ec75db..41d3b07323ab10f6bc386b9977f8b116fde4f270 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -2,7 +2,6 @@ package main import ( "io" - "log" "net" "net/http" _ "net/http/pprof" @@ -117,11 +116,12 @@ func main() { go func() { source := r.NewSource() client := frontends.NewClient(conn) + defer client.Close(nil) done := make(chan struct{}) + defer close(done) for { reader, err := pnet.PreRead(client) if err != nil { - log.Println("failed", err) break } source.Schedule(job{ diff --git a/lib/pnet/ioreader.go b/lib/pnet/ioreader.go index 67f68b7323f7687873c52b593d0e991843ecd5e9..de22cf3a11b7063ddf467a00bf136add8db4e12d 100644 --- a/lib/pnet/ioreader.go +++ b/lib/pnet/ioreader.go @@ -14,7 +14,7 @@ type IOReader struct { reader io.Reader // header buffer for reading packet headers // (allocating within Read would escape to heap) - header [4]byte + header [5]byte buf packet.InBuf payload []byte @@ -22,8 +22,7 @@ type IOReader struct { func MakeIOReader(reader io.Reader) IOReader { return IOReader{ - reader: reader, - payload: make([]byte, 1024), + reader: reader, } } @@ -35,7 +34,8 @@ func NewIOReader(reader io.Reader) *IOReader { // Read fetches the next packet from the underlying io.Reader and gives you a packet.In // Calling Read will invalidate all other packet.In's for this IOReader func (T *IOReader) Read() (packet.In, error) { - typ, err := T.ReadByte() + // read header + _, err := io.ReadFull(T.reader, T.header[:]) if err != nil { return packet.In{}, err } @@ -46,7 +46,7 @@ func (T *IOReader) Read() (packet.In, error) { } T.buf.Reset( - packet.Type(typ), + packet.Type(T.header[0]), T.payload, ) @@ -59,7 +59,13 @@ func (T *IOReader) Read() (packet.In, error) { // ReadUntyped is similar to Read, but it doesn't read a packet.Type func (T *IOReader) ReadUntyped() (packet.In, error) { - err := T.readPayload() + // read header + _, err := io.ReadFull(T.reader, T.header[1:]) + if err != nil { + return packet.In{}, err + } + + err = T.readPayload() if err != nil { return packet.In{}, err } @@ -77,22 +83,12 @@ func (T *IOReader) ReadUntyped() (packet.In, error) { } func (T *IOReader) readPayload() error { - if T.payload == nil { - panic("Previous Read was never finished") - } - - // read length int32 - _, err := io.ReadFull(T.reader, T.header[:]) - if err != nil { - return err - } - - length := binary.BigEndian.Uint32(T.header[:]) - 4 + length := binary.BigEndian.Uint32(T.header[1:]) - 4 // resize body to length T.payload = slices.Resize(T.payload, int(length)) // read body - _, err = io.ReadFull(T.reader, T.payload) + _, err := io.ReadFull(T.reader, T.payload) if err != nil { return err } diff --git a/lib/pnet/iowriter.go b/lib/pnet/iowriter.go index 65fad924f204fb22d6b62a091473513ea61195ac..c1ff75c72ce8f609ecc9353dfd959b2d4ace740f 100644 --- a/lib/pnet/iowriter.go +++ b/lib/pnet/iowriter.go @@ -13,7 +13,7 @@ type IOWriter struct { writer io.Writer // header buffer for writing packet headers // (allocating within Write would escape to heap) - header [4]byte + header [5]byte buf packet.OutBuf } @@ -49,23 +49,25 @@ func (T *IOWriter) write(typ packet.Type, payload []byte) error { log.Println("write untyped packet", payload) } */ - // write type byte (if present) + // prepare header + T.header[0] = byte(typ) + binary.BigEndian.PutUint32(T.header[1:], uint32(len(payload)+4)) + + // write header if typ != packet.None { - err := T.WriteByte(byte(typ)) + _, err := T.writer.Write(T.header[:]) + if err != nil { + return err + } + } else { + _, err := T.writer.Write(T.header[1:]) if err != nil { return err } - } - - // write len+4 - binary.BigEndian.PutUint32(T.header[:], uint32(len(payload)+4)) - _, err := T.writer.Write(T.header[:]) - if err != nil { - return err } // write payload - _, err = T.writer.Write(payload) + _, err := T.writer.Write(payload) if err != nil { return err } diff --git a/lib/pnet/preread.go b/lib/pnet/preread.go index 9399574fe851c333da70f8ff6d0e20ca6d2d6cf1..c8a9963d208843135babe4ab819a5a20026828b9 100644 --- a/lib/pnet/preread.go +++ b/lib/pnet/preread.go @@ -4,6 +4,8 @@ import ( "pggat2/lib/pnet/packet" ) +// PreRead returns a buffered reader containing the first packet +// useful for waiting for a full packet before actually doing work func PreRead(reader Reader) (Reader, error) { in, err := reader.Read() if err != nil { @@ -12,6 +14,7 @@ func PreRead(reader Reader) (Reader, error) { return newPolled(in, reader), nil } +// PreReadUntyped does the same thing as PreReadUntyped but uses Reader.ReadUntyped func PreReadUntyped(reader Reader) (Reader, error) { in, err := reader.ReadUntyped() if err != nil { @@ -20,20 +23,20 @@ func PreReadUntyped(reader Reader) (Reader, error) { return newPolled(in, reader), nil } -type polled struct { +type preRead struct { in packet.In read bool reader Reader } -func newPolled(in packet.In, reader Reader) *polled { - return &polled{ +func newPolled(in packet.In, reader Reader) *preRead { + return &preRead{ in: in, reader: reader, } } -func (T *polled) Read() (packet.In, error) { +func (T *preRead) Read() (packet.In, error) { if !T.read { T.read = true return T.in, nil @@ -41,7 +44,7 @@ func (T *polled) Read() (packet.In, error) { return T.reader.Read() } -func (T *polled) ReadUntyped() (packet.In, error) { +func (T *preRead) ReadUntyped() (packet.In, error) { if !T.read { T.read = true return T.in, nil @@ -49,4 +52,4 @@ func (T *polled) ReadUntyped() (packet.In, error) { return T.reader.ReadUntyped() } -var _ Reader = (*polled)(nil) +var _ Reader = (*preRead)(nil)