// +build !js

package websocket

import (
	"context"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"strings"
	"sync/atomic"
	"time"

	"nhooyr.io/websocket/internal/errd"
)

Anmol Sethi's avatar
Anmol Sethi committed
// Reader reads from the connection until until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
Anmol Sethi's avatar
Anmol Sethi committed
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
Anmol Sethi's avatar
Anmol Sethi committed
// Call CloseRead if you do not expect any data messages from the peer.
Anmol Sethi's avatar
Anmol Sethi committed
//
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	return c.reader(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Read is a convenience method around Reader to read a single message
// from the connection.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	b, err := ioutil.ReadAll(r)
	return typ, b, err
}

Anmol Sethi's avatar
Anmol Sethi committed
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
Anmol Sethi's avatar
Anmol Sethi committed
// The returned context will be cancelled when the connection is closed.
//
Anmol Sethi's avatar
Anmol Sethi committed
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	ctx, cancel := context.WithCancel(ctx)
	go func() {
		defer cancel()
		c.Reader(ctx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()
	return ctx
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
Anmol Sethi's avatar
Anmol Sethi committed
	c.msgReader.limitReader.limit.Store(n)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func newMsgReader(c *Conn) *msgReader {
	mr := &msgReader{
		c:   c,
		fin: true,
	}

	mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768)
Anmol Sethi's avatar
Anmol Sethi committed
	if c.flate() && mr.flateContextTakeover() {
Anmol Sethi's avatar
Anmol Sethi committed
		mr.initFlateReader()
	}

	return mr
}

func (mr *msgReader) initFlateReader() {
Anmol Sethi's avatar
Anmol Sethi committed
	mr.flateReader = getFlateReader(readerFunc(mr.read))
	mr.limitReader.r = mr.flateReader
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) close() {
Anmol Sethi's avatar
Anmol Sethi committed
	mr.c.readMu.Lock(context.Background())
	defer mr.c.readMu.Unlock()

Anmol Sethi's avatar
Anmol Sethi committed
	mr.returnFlateReader()
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) flateContextTakeover() bool {
Anmol Sethi's avatar
Anmol Sethi committed
	if mr.c.client {
Anmol Sethi's avatar
Anmol Sethi committed
		return !mr.c.copts.serverNoContextTakeover
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	return !mr.c.copts.clientNoContextTakeover
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readRSV1Illegal(h header) bool {
Anmol Sethi's avatar
Anmol Sethi committed
	// If compression is enabled, rsv1 is always illegal.
Anmol Sethi's avatar
Anmol Sethi committed
	if !c.flate() {
Anmol Sethi's avatar
Anmol Sethi committed
		return true
	}
	// rsv1 is only allowed on data frames beginning messages.
	if h.opcode != opText && h.opcode != opBinary {
		return true
	}
	return false
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readLoop(ctx context.Context) (header, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	for {
Anmol Sethi's avatar
Anmol Sethi committed
		h, err := c.readFrameHeader(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return header{}, err
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
Anmol Sethi's avatar
Anmol Sethi committed
			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if !c.client && !h.masked {
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, errors.New("received unmasked frame from client")
		}

		switch h.opcode {
		case opClose, opPing, opPong:
Anmol Sethi's avatar
Anmol Sethi committed
			err = c.handleControl(ctx, h)
Anmol Sethi's avatar
Anmol Sethi committed
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
			err := fmt.Errorf("received unknown opcode %v", h.opcode)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}
	}
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- ctx:
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	h, err := readFrameHeader(c.br)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		select {
Anmol Sethi's avatar
Anmol Sethi committed
		case <-c.closed:
			return header{}, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return header{}, c.closeErr
	case c.readTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}

	return h, nil
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return 0, c.closeErr
	case c.readTimeout <- ctx:
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	n, err := io.ReadFull(c.br, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		select {
Anmol Sethi's avatar
Anmol Sethi committed
		case <-c.closed:
			return n, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		case <-ctx.Done():
			return n, ctx.Err()
		default:
			err = fmt.Errorf("failed to read frame payload: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
			return n, err
		}
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return n, c.closeErr
	case c.readTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}

	return n, err
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

Anmol Sethi's avatar
Anmol Sethi committed
	b := c.readControlBuf[:h.payloadLength]
Anmol Sethi's avatar
Anmol Sethi committed
	_, err = c.readFramePayload(ctx, b)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
Anmol Sethi's avatar
Anmol Sethi committed
		return c.writeControl(ctx, opPong, b)
Anmol Sethi's avatar
Anmol Sethi committed
	case opPong:
Anmol Sethi's avatar
Anmol Sethi committed
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed
		if ok {
			close(pong)
		}
		return nil
	}

Anmol Sethi's avatar
Anmol Sethi committed
	defer func() {
		c.readCloseFrameErr = err
	}()

Anmol Sethi's avatar
Anmol Sethi committed
	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
Anmol Sethi's avatar
Anmol Sethi committed
	c.setCloseErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
	c.writeClose(ce.Code, ce.Reason)
Anmol Sethi's avatar
Anmol Sethi committed
	c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
	return err
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	err = c.readMu.Lock(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return 0, nil, err
	}
Anmol Sethi's avatar
Anmol Sethi committed
	defer c.readMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	if !c.msgReader.fin {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, errors.New("previous message not read to completion")
	}

Anmol Sethi's avatar
Anmol Sethi committed
	h, err := c.readLoop(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return 0, nil, err
	}

	if h.opcode == opContinuation {
		err := errors.New("received continuation frame without text or binary frame")
Anmol Sethi's avatar
Anmol Sethi committed
		c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	c.msgReader.reset(ctx, h)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	return MessageType(h.opcode), c.msgReader, nil
Anmol Sethi's avatar
Anmol Sethi committed
}

type msgReader struct {
Anmol Sethi's avatar
Anmol Sethi committed
	c *Conn
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	ctx         context.Context
	deflate     bool
	flateReader io.Reader
	deflateTail strings.Reader
	limitReader *limitReader
Anmol Sethi's avatar
Anmol Sethi committed
	fin           bool
Anmol Sethi's avatar
Anmol Sethi committed
	payloadLength int64
	maskKey       uint32
}

func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx = ctx
	mr.deflate = h.rsv1
	if mr.deflate {
Anmol Sethi's avatar
Anmol Sethi committed
		if !mr.flateContextTakeover() {
Anmol Sethi's avatar
Anmol Sethi committed
			mr.initFlateReader()
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed
		mr.deflateTail.Reset(deflateMessageTail)
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

	mr.limitReader.reset()
Anmol Sethi's avatar
Anmol Sethi committed
	mr.setFrame(h)
}

func (mr *msgReader) setFrame(h header) {
Anmol Sethi's avatar
Anmol Sethi committed
	mr.fin = h.fin
Anmol Sethi's avatar
Anmol Sethi committed
	mr.payloadLength = h.payloadLength
	mr.maskKey = h.maskKey
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) Read(p []byte) (n int, err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	defer func() {
Anmol Sethi's avatar
Anmol Sethi committed
		r := recover()
		if r != nil {
			if r != "ANMOL" {
				panic(r)
			}
			err = io.EOF
			if !mr.flateContextTakeover() {
				mr.returnFlateReader()
			}
		}

Anmol Sethi's avatar
Anmol Sethi committed
		errd.Wrap(&err, "failed to read")
		if errors.Is(err, io.EOF) {
			err = io.EOF
		}
	}()

Anmol Sethi's avatar
Anmol Sethi committed
	err = mr.c.readMu.Lock(mr.ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return 0, err
	}
Anmol Sethi's avatar
Anmol Sethi committed
	defer mr.c.readMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	return mr.limitReader.Read(p)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) returnFlateReader() {
	if mr.flateReader != nil {
		putFlateReader(mr.flateReader)
		mr.flateReader = nil
	}
}

Anmol Sethi's avatar
Anmol Sethi committed
func (mr *msgReader) read(p []byte) (int, error) {
	if mr.payloadLength == 0 {
Anmol Sethi's avatar
Anmol Sethi committed
		if mr.fin {
			if mr.deflate {
				if mr.deflateTail.Len() == 0 {
					panic("ANMOL")
				}
				n, _ := mr.deflateTail.Read(p)
				return n, nil
			}
			return 0, io.EOF
Anmol Sethi's avatar
Anmol Sethi committed
		}

		h, err := mr.c.readLoop(mr.ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return 0, err
		}
		if h.opcode != opContinuation {
			err := errors.New("received new data message without finishing the previous message")
Anmol Sethi's avatar
Anmol Sethi committed
			mr.c.writeError(StatusProtocolError, err)
Anmol Sethi's avatar
Anmol Sethi committed
			return 0, err
		}
		mr.setFrame(h)
	}

	if int64(len(p)) > mr.payloadLength {
		p = p[:mr.payloadLength]
	}

Anmol Sethi's avatar
Anmol Sethi committed
	n, err := mr.c.readFramePayload(mr.ctx, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return n, err
	}

	mr.payloadLength -= int64(n)

Anmol Sethi's avatar
Anmol Sethi committed
	if !mr.c.client {
Anmol Sethi's avatar
Anmol Sethi committed
		mr.maskKey = mask(mr.maskKey, p)
	}

	return n, nil
}

// limitReader reads message bytes from r while enforcing a per-message byte
// limit. Violations are reported to the peer through c.
type limitReader struct {
	// c is the connection that is told about limit violations.
	c *Conn
	// r supplies the message bytes: the raw payload or a flate reader.
	r io.Reader
	// n is the remaining allowance for the current message.
	n int64
	// limit is the configured per-message maximum. It lives in an
	// atomicInt64, presumably so SetReadLimit can update it safely.
	limit atomicInt64
}

func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := &limitReader{
		c: c,
	}
	lr.limit.Store(limit)
Anmol Sethi's avatar
Anmol Sethi committed
	lr.r = r
	lr.reset()
Anmol Sethi's avatar
Anmol Sethi committed
	return lr
}

Anmol Sethi's avatar
Anmol Sethi committed
func (lr *limitReader) reset() {
Anmol Sethi's avatar
Anmol Sethi committed
	lr.n = lr.limit.Load()
Anmol Sethi's avatar
Anmol Sethi committed
func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n <= 0 {
		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
Anmol Sethi's avatar
Anmol Sethi committed
		lr.c.writeError(StatusMessageTooBig, err)
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, err
	}

	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	return n, err
}

// atomicInt64 holds an int64 that can be stored and loaded concurrently.
// Its zero value reads as 0.
//
// NOTE(review): it wraps atomic.Value rather than the sync/atomic integer
// functions, presumably to sidestep 64-bit alignment requirements on 32-bit
// platforms; confirm.
type atomicInt64 struct {
	i atomic.Value
}

// Store atomically replaces the held value with i.
func (v *atomicInt64) Store(i int64) {
	v.i.Store(i)
}

// Load atomically returns the held value, or 0 if Store was never called.
func (v *atomicInt64) Load() int64 {
	if stored, ok := v.i.Load().(int64); ok {
		return stored
	}
	return 0
}

// readerFunc adapts an ordinary function to the io.Reader interface, in the
// same spirit as http.HandlerFunc.
type readerFunc func(p []byte) (int, error)

// Read satisfies io.Reader by invoking the wrapped function with p.
func (f readerFunc) Read(p []byte) (int, error) {
	n, err := f(p)
	return n, err
}