good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 1200707b authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Ensure connection is closed at all error points

Closes #191
parent 43c4dc08
No related branches found
No related tags found
No related merge requests found
......@@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
defer c.readMu.unlock()
if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
err = errors.New("previous message not read to completion")
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}
h, err := c.readLoop(ctx)
......@@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
}
func (mr *msgReader) Read(p []byte) (n int, err error) {
defer func() {
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
err = io.EOF
}
if errors.Is(err, io.EOF) {
err = io.EOF
mr.putFlateReader()
return
}
errd.Wrap(&err, "failed to read")
}()
err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()
......@@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.close(err)
}
return n, err
}
......
......@@ -155,11 +155,16 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write")
mw.writeMu.Lock()
defer mw.writeMu.Unlock()
defer func() {
err = fmt.Errorf("failed to write: %w", err)
if err != nil {
mw.c.close(err)
}
}()
if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
......@@ -230,8 +235,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
}
// frame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
err := c.writeFrameMu.lock(ctx)
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
......@@ -243,6 +248,12 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
case c.writeTimeout <- ctx:
}
defer func() {
if err != nil {
c.close(fmt.Errorf("failed to write frame: %w", err))
}
}()
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment