good morning!!!!

Skip to content
Snippets Groups Projects
read.go 10.7 KiB
Newer Older
Anmol Sethi's avatar
Anmol Sethi committed
package websocket

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"strings"
	"sync/atomic"
	"time"
Anmol Sethi's avatar
Anmol Sethi committed

	"nhooyr.io/websocket/internal/errd"
Anmol Sethi's avatar
Anmol Sethi committed
)

// Reader waits until there is a WebSocket data message to read
// from the connection.
// It returns the type of the message and a reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// All returned errors will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
//
// You must read from the connection for control frames to be handled.
// Thus if you expect messages to take a long time to be responded to,
// you should handle such messages async to reading from the connection
// to ensure control frames are promptly handled.
//
// If you do not expect any data messages from the peer, call CloseRead.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and then the message
// Read, use time.AfterFunc to cancel the context passed in early.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	typ, r, err := c.cr.reader(ctx)
	if err != nil {
		return 0, nil, fmt.Errorf("failed to get reader: %w", err)
	}
	return typ, r, nil
}

// Read is a convenience method to read a single message from the connection.
//
// See the Reader method to reuse buffers or for streaming.
// The docs on Reader apply to this method as well.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	b, err := ioutil.ReadAll(r)
	return typ, b, err
}

// CloseRead will start a goroutine to read from the connection until it is closed or a data message
// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
// After calling this method, you cannot read any data messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// Use this when you do not want to read data messages from the connection anymore but will
// want to write messages to it.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	ctx, cancel := context.WithCancel(ctx)
	go func() {
		defer cancel()
		c.Reader(ctx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()
	return ctx
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
	c.cr.mr.lr.limit.Store(n)
}

type connReader struct {
	c       *Conn
	br      *bufio.Reader
	timeout chan context.Context

	mu                mu
	controlPayloadBuf [maxControlPayload]byte
	mr                *msgReader
}

func (cr *connReader) init(c *Conn, br *bufio.Reader) {
	cr.c = c
	cr.br = br
	cr.timeout = make(chan context.Context)

	cr.mr = &msgReader{
		cr:  cr,
		fin: true,
	}

	cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768)
	if c.deflateNegotiated() && cr.contextTakeover() {
		cr.ensureFlateReader()
	}
}

func (cr *connReader) ensureFlateReader() {
	cr.mr.fr = getFlateReader(readerFunc(cr.mr.read))
	cr.mr.lr.reset(cr.mr.fr)
}

func (cr *connReader) close() {
	cr.mu.Lock(context.Background())
	if cr.c.client {
		putBufioReader(cr.br)
	}
	if cr.c.deflateNegotiated() && cr.contextTakeover() {
		putFlateReader(cr.mr.fr)
	}
}

func (cr *connReader) contextTakeover() bool {
	if cr.c.client {
		return cr.c.copts.serverNoContextTakeover
	}
	return cr.c.copts.clientNoContextTakeover
}

func (cr *connReader) rsv1Illegal(h header) bool {
	// If compression is enabled, rsv1 is always illegal.
	if !cr.c.deflateNegotiated() {
		return true
	}
	// rsv1 is only allowed on data frames beginning messages.
	if h.opcode != opText && h.opcode != opBinary {
		return true
	}
	return false
}

func (cr *connReader) loop(ctx context.Context) (header, error) {
	for {
		h, err := cr.frameHeader(ctx)
		if err != nil {
			return header{}, err
		}

		if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 {
			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			cr.c.cw.error(StatusProtocolError, err)
			return header{}, err
		}

		if !cr.c.client && !h.masked {
			return header{}, errors.New("received unmasked frame from client")
		}

		switch h.opcode {
		case opClose, opPing, opPong:
			err = cr.control(ctx, h)
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
			err := fmt.Errorf("received unknown opcode %v", h.opcode)
			cr.c.cw.error(StatusProtocolError, err)
			return header{}, err
		}
	}
}

func (cr *connReader) frameHeader(ctx context.Context) (header, error) {
	select {
	case <-cr.c.closed:
		return header{}, cr.c.closeErr
	case cr.timeout <- ctx:
	}

	h, err := readFrameHeader(cr.br)
	if err != nil {
		select {
		case <-cr.c.closed:
			return header{}, cr.c.closeErr
		case <-ctx.Done():
			return header{}, ctx.Err()
		default:
			cr.c.closeWithErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
			return header{}, err
		}
	}

	select {
	case <-cr.c.closed:
		return header{}, cr.c.closeErr
	case cr.timeout <- context.Background():
	}

	return h, nil
}

func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) {
	select {
	case <-cr.c.closed:
		return 0, cr.c.closeErr
	case cr.timeout <- ctx:
	}

	n, err := io.ReadFull(cr.br, p)
	if err != nil {
		select {
		case <-cr.c.closed:
			return n, cr.c.closeErr
		case <-ctx.Done():
			return n, ctx.Err()
		default:
			err = fmt.Errorf("failed to read frame payload: %w", err)
			cr.c.closeWithErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
			return n, err
		}
	}

	select {
	case <-cr.c.closed:
		return n, cr.c.closeErr
	case cr.timeout <- context.Background():
	}

	return n, err
}

func (cr *connReader) control(ctx context.Context, h header) error {
	if h.payloadLength < 0 {
		err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength)
		cr.c.cw.error(StatusProtocolError, err)
		return err
	}

	if h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength)
		cr.c.cw.error(StatusProtocolError, err)
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
		cr.c.cw.error(StatusProtocolError, err)
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	b := cr.controlPayloadBuf[:h.payloadLength]
	_, err := cr.framePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		mask(h.maskKey, b)
	}

	switch h.opcode {
	case opPing:
		return cr.c.cw.control(ctx, opPong, b)
	case opPong:
		cr.c.activePingsMu.Lock()
		pong, ok := cr.c.activePings[string(b)]
		cr.c.activePingsMu.Unlock()
		if ok {
			close(pong)
		}
		return nil
	}

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		cr.c.cw.error(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	cr.c.setCloseErr(err)
	cr.c.cw.control(context.Background(), opClose, ce.bytes())
	return err
}

func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) {
	err := cr.mu.Lock(ctx)
	if err != nil {
		return 0, nil, err
	}
	defer cr.mu.Unlock()

	if !cr.mr.fin {
		return 0, nil, errors.New("previous message not read to completion")
	}

	h, err := cr.loop(ctx)
	if err != nil {
		return 0, nil, err
	}

	if h.opcode == opContinuation {
		err := errors.New("received continuation frame without text or binary frame")
		cr.c.cw.error(StatusProtocolError, err)
		return 0, nil, err
	}

	cr.mr.reset(ctx, h)

	return MessageType(h.opcode), cr.mr, nil
}

type msgReader struct {
	cr *connReader
	fr io.Reader
	lr *limitReader

	ctx context.Context

	deflate     bool
	deflateTail strings.Reader

	payloadLength int64
	maskKey       uint32
	fin           bool
}

func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx = ctx
	mr.deflate = h.rsv1
	if mr.deflate {
		mr.deflateTail.Reset(deflateMessageTail)
		if !mr.cr.contextTakeover() {
			mr.cr.ensureFlateReader()
		}
	}
	mr.setFrame(h)
	mr.fin = false
}

func (mr *msgReader) setFrame(h header) {
	mr.payloadLength = h.payloadLength
	mr.maskKey = h.maskKey
	mr.fin = h.fin
}

func (mr *msgReader) Read(p []byte) (_ int, err error) {
	defer func() {
		errd.Wrap(&err, "failed to read")
		if errors.Is(err, io.EOF) {
			err = io.EOF
		}
	}()

	err = mr.cr.mu.Lock(mr.ctx)
	if err != nil {
		return 0, err
	}
	defer mr.cr.mu.Unlock()

	if mr.payloadLength == 0 && mr.fin {
		if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() {
			if mr.fr != nil {
				putFlateReader(mr.fr)
				mr.fr = nil
			}
		}
		return 0, io.EOF
	}

	return mr.lr.Read(p)
}

func (mr *msgReader) read(p []byte) (int, error) {
	if mr.payloadLength == 0 {
		h, err := mr.cr.loop(mr.ctx)
		if err != nil {
			return 0, err
		}
		if h.opcode != opContinuation {
			err := errors.New("received new data message without finishing the previous message")
			mr.cr.c.cw.error(StatusProtocolError, err)
			return 0, err
		}
		mr.setFrame(h)
	}

	if int64(len(p)) > mr.payloadLength {
		p = p[:mr.payloadLength]
	}

	n, err := mr.cr.framePayload(mr.ctx, p)
	if err != nil {
		return n, err
	}

	mr.payloadLength -= int64(n)

	if !mr.cr.c.client {
		mr.maskKey = mask(mr.maskKey, p)
	}

	return n, nil
}

type limitReader struct {
	c     *Conn
	r     io.Reader
	limit atomicInt64
	n     int64
}

func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := &limitReader{
		c: c,
	}
	lr.limit.Store(limit)
	lr.reset(r)
	return lr
}

func (lr *limitReader) reset(r io.Reader) {
	lr.n = lr.limit.Load()
	lr.r = r
}

func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n <= 0 {
		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
		lr.c.cw.error(StatusMessageTooBig, err)
		return 0, err
	}

	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	return n, err
}

type atomicInt64 struct {
	i atomic.Value
}

func (v *atomicInt64) Load() int64 {
	i, _ := v.i.Load().(int64)
	return i
}

func (v *atomicInt64) Store(i int64) {
	v.i.Store(i)
}

type readerFunc func(p []byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
	return f(p)
}