good morning!!!!

Skip to content
Snippets Groups Projects
read.go 9.79 KiB
Newer Older
Anmol Sethi's avatar
Anmol Sethi committed
// +build !js

Anmol Sethi's avatar
Anmol Sethi committed
package websocket

import (
	"context"
	"io"
	"io/ioutil"
	"strings"
	"time"
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	"golang.org/x/xerrors"

Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/internal/errd"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/internal/xsync"
Anmol Sethi's avatar
Anmol Sethi committed
)

Anmol Sethi's avatar
Anmol Sethi committed
// Reader reads from the connection until until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
Anmol Sethi's avatar
Anmol Sethi committed
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
Anmol Sethi's avatar
Anmol Sethi committed
// Call CloseRead if you do not expect any data messages from the peer.
Anmol Sethi's avatar
Anmol Sethi committed
//
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	return c.reader(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Read is a convenience method around Reader to read a single message
// from the connection.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	b, err := ioutil.ReadAll(r)
	return typ, b, err
}

Anmol Sethi's avatar
Anmol Sethi committed
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
Anmol Sethi's avatar
Anmol Sethi committed
// The returned context will be cancelled when the connection is closed.
//
Anmol Sethi's avatar
Anmol Sethi committed
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	ctx, cancel := context.WithCancel(ctx)
	go func() {
		defer cancel()
		c.Reader(ctx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()
	return ctx
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
Anmol Sethi's avatar
Anmol Sethi committed
	c.msgReader.limitReader.limit.Store(n)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// defaultReadLimit is the per-message read limit (32 KiB) a connection starts
// with until SetReadLimit is called.
const defaultReadLimit = 1 << 15

Anmol Sethi's avatar
Anmol Sethi committed
func newMsgReader(c *Conn) *msgReader {
	mr := &msgReader{
		c:   c,
		fin: true,
	}

Anmol Sethi's avatar
Anmol Sethi committed
	mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit)
Anmol Sethi's avatar
Anmol Sethi committed
	return mr
}

func (mr *msgReader) resetFlate() {
Anmol Sethi's avatar
Anmol Sethi committed
	if mr.flateContextTakeover() && mr.dict == nil {
		mr.dict = newSlidingWindow(32768)
	}

	if mr.flateContextTakeover() {
		mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf)
	} else {
		mr.flateReader = getFlateReader(readerFunc(mr.read), nil)
	}
Anmol Sethi's avatar
Anmol Sethi committed
	mr.limitReader.r = mr.flateReader
	mr.flateTail.Reset(deflateMessageTail)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// returnFlateReader hands mr.flateReader back via putFlateReader and clears
// the field. It does nothing when no flate reader is held, so calling it more
// than once is harmless.
func (mr *msgReader) returnFlateReader() {
	if mr.flateReader == nil {
		return
	}
	putFlateReader(mr.flateReader)
	mr.flateReader = nil
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) close() {
Anmol Sethi's avatar
Anmol Sethi committed
	mr.c.readMu.Lock(context.Background())
Anmol Sethi's avatar
Anmol Sethi committed
	mr.returnFlateReader()
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) flateContextTakeover() bool {
Anmol Sethi's avatar
Anmol Sethi committed
	if mr.c.client {
Anmol Sethi's avatar
Anmol Sethi committed
		return !mr.c.copts.serverNoContextTakeover
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	return !mr.c.copts.clientNoContextTakeover
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readRSV1Illegal(h header) bool {
Anmol Sethi's avatar
Anmol Sethi committed
	// If compression is enabled, rsv1 is always illegal.
Anmol Sethi's avatar
Anmol Sethi committed
	if !c.flate() {
Anmol Sethi's avatar
Anmol Sethi committed
		return true
	}
	// rsv1 is only allowed on data frames beginning messages.
	if h.opcode != opText && h.opcode != opBinary {
		return true
	}
	return false
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readLoop(ctx context.Context) (header, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	for {
Anmol Sethi's avatar
Anmol Sethi committed
		h, err := c.readFrameHeader(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return header{}, err
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
Anmol Sethi's avatar
Anmol Sethi committed
			err := xerrors.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if !c.client && !h.masked {
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, xerrors.New("received unmasked frame from client")
Anmol Sethi's avatar
Anmol Sethi committed
		}

		switch h.opcode {
		case opClose, opPing, opPong:
Anmol Sethi's avatar
Anmol Sethi committed
			err = c.handleControl(ctx, h)
Anmol Sethi's avatar
Anmol Sethi committed
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
Anmol Sethi's avatar
Anmol Sethi committed
				return header{}, xerrors.Errorf("failed to handle control frame %v: %w", h.opcode, err)
Anmol Sethi's avatar
Anmol Sethi committed
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
Anmol Sethi's avatar
Anmol Sethi committed
			err := xerrors.Errorf("received unknown opcode %v", h.opcode)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}
	}
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- ctx:
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	h, err := readFrameHeader(c.br)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		select {
Anmol Sethi's avatar
Anmol Sethi committed
		case <-c.closed:
			return header{}, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}

	return h, nil
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return 0, c.closeErr
	case c.readTimeout <- ctx:
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	n, err := io.ReadFull(c.br, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		select {
Anmol Sethi's avatar
Anmol Sethi committed
		case <-c.closed:
			return n, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		case <-ctx.Done():
			return n, ctx.Err()
		default:
Anmol Sethi's avatar
Anmol Sethi committed
			err = xerrors.Errorf("failed to read frame payload: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
			return n, err
		}
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return n, c.closeErr
	case c.readTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}

	return n, err
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
Anmol Sethi's avatar
Anmol Sethi committed
		err := xerrors.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}

	if !h.fin {
Anmol Sethi's avatar
Anmol Sethi committed
		err := xerrors.New("received fragmented control frame")
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

Anmol Sethi's avatar
Anmol Sethi committed
	b := c.readControlBuf[:h.payloadLength]
Anmol Sethi's avatar
Anmol Sethi committed
	_, err = c.readFramePayload(ctx, b)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
Anmol Sethi's avatar
Anmol Sethi committed
		return c.writeControl(ctx, opPong, b)
Anmol Sethi's avatar
Anmol Sethi committed
	case opPong:
Anmol Sethi's avatar
Anmol Sethi committed
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed
		if ok {
			close(pong)
		}
		return nil
	}

Anmol Sethi's avatar
Anmol Sethi committed
	defer func() {
		c.readCloseFrameErr = err
	}()

Anmol Sethi's avatar
Anmol Sethi committed
	ce, err := parseClosePayload(b)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		err = xerrors.Errorf("received invalid close payload: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	err = xerrors.Errorf("received close frame: %w", ce)
Anmol Sethi's avatar
Anmol Sethi committed
	c.setCloseErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
	c.writeClose(ce.Code, ce.Reason)
Anmol Sethi's avatar
Anmol Sethi committed
	c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
	return err
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	err = c.readMu.Lock(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return 0, nil, err
	}
Anmol Sethi's avatar
Anmol Sethi committed
	defer c.readMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	if !c.msgReader.fin {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, xerrors.New("previous message not read to completion")
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	h, err := c.readLoop(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return 0, nil, err
	}

	if h.opcode == opContinuation {
Anmol Sethi's avatar
Anmol Sethi committed
		err := xerrors.New("received continuation frame without text or binary frame")
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	c.msgReader.reset(ctx, h)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	return MessageType(h.opcode), c.msgReader, nil
Anmol Sethi's avatar
Anmol Sethi committed
}

type msgReader struct {
Anmol Sethi's avatar
Anmol Sethi committed
	c *Conn
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	ctx         context.Context
Anmol Sethi's avatar
Anmol Sethi committed
	flate       bool
Anmol Sethi's avatar
Anmol Sethi committed
	flateReader io.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	flateTail   strings.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	limitReader *limitReader
Anmol Sethi's avatar
Anmol Sethi committed
	dict        *slidingWindow
Anmol Sethi's avatar
Anmol Sethi committed
	fin           bool
Anmol Sethi's avatar
Anmol Sethi committed
	payloadLength int64
	maskKey       uint32
}

func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx = ctx
Anmol Sethi's avatar
Anmol Sethi committed
	mr.flate = h.rsv1
	mr.limitReader.reset(readerFunc(mr.read))

Anmol Sethi's avatar
Anmol Sethi committed
	if mr.flate {
		mr.resetFlate()
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	mr.setFrame(h)
}

func (mr *msgReader) setFrame(h header) {
Anmol Sethi's avatar
Anmol Sethi committed
	mr.fin = h.fin
Anmol Sethi's avatar
Anmol Sethi committed
	mr.payloadLength = h.payloadLength
	mr.maskKey = h.maskKey
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) Read(p []byte) (n int, err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	defer func() {
Anmol Sethi's avatar
Anmol Sethi committed
		errd.Wrap(&err, "failed to read")
		if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
Anmol Sethi's avatar
Anmol Sethi committed
			err = io.EOF
		}
Anmol Sethi's avatar
Anmol Sethi committed
		if xerrors.Is(err, io.EOF) {
Anmol Sethi's avatar
Anmol Sethi committed
			err = io.EOF
		}
	}()

Anmol Sethi's avatar
Anmol Sethi committed
	err = mr.c.readMu.Lock(mr.ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return 0, err
	}
Anmol Sethi's avatar
Anmol Sethi committed
	defer mr.c.readMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	n, err = mr.limitReader.Read(p)
	if mr.flate && mr.flateContextTakeover() {
Anmol Sethi's avatar
Anmol Sethi committed
		p = p[:n]
		mr.dict.write(p)
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	return n, err
Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) read(p []byte) (int, error) {
	if mr.payloadLength == 0 {
Anmol Sethi's avatar
Anmol Sethi committed
		if mr.fin {
Anmol Sethi's avatar
Anmol Sethi committed
			if mr.flate {
				n, err := mr.flateTail.Read(p)
				if xerrors.Is(err, io.EOF) {
					mr.returnFlateReader()
Anmol Sethi's avatar
Anmol Sethi committed
				}
Anmol Sethi's avatar
Anmol Sethi committed
				return n, err
Anmol Sethi's avatar
Anmol Sethi committed
			}
			return 0, io.EOF
Anmol Sethi's avatar
Anmol Sethi committed
		}

		h, err := mr.c.readLoop(mr.ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return 0, err
		}
		if h.opcode != opContinuation {
Anmol Sethi's avatar
Anmol Sethi committed
			err := xerrors.New("received new data message without finishing the previous message")
Anmol Sethi's avatar
Anmol Sethi committed
			mr.c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
			return 0, err
		}
		mr.setFrame(h)
	}

	if int64(len(p)) > mr.payloadLength {
		p = p[:mr.payloadLength]
	}

Anmol Sethi's avatar
Anmol Sethi committed
	n, err := mr.c.readFramePayload(mr.ctx, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return n, err
	}

	mr.payloadLength -= int64(n)

Anmol Sethi's avatar
Anmol Sethi committed
	if !mr.c.client {
Anmol Sethi's avatar
Anmol Sethi committed
		mr.maskKey = mask(mr.maskKey, p)
	}

	return n, nil
}

type limitReader struct {
	c     *Conn
	r     io.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	limit xsync.Int64
Anmol Sethi's avatar
Anmol Sethi committed
	n     int64
}

func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := &limitReader{
		c: c,
	}
	lr.limit.Store(limit)
Anmol Sethi's avatar
Anmol Sethi committed
	return lr
}

func (lr *limitReader) reset(r io.Reader) {
Anmol Sethi's avatar
Anmol Sethi committed
	lr.n = lr.limit.Load()
Anmol Sethi's avatar
Anmol Sethi committed
func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n <= 0 {
Anmol Sethi's avatar
Anmol Sethi committed
		err := xerrors.Errorf("read limited at %v bytes", lr.limit.Load())
Anmol Sethi's avatar
Anmol Sethi committed
		lr.c.writeError(StatusMessageTooBig, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, err
	}

	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	return n, err
}

// readerFunc lets a plain function satisfy io.Reader: calling Read on it
// simply invokes the function.
type readerFunc func(p []byte) (int, error)

// Read forwards p to f and returns whatever f returns.
func (f readerFunc) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}