Newer
Older
package websocket
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"strings"
"sync/atomic"
"time"
// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
// Read is a convenience wrapper around Reader that reads one whole message
// from the connection into memory.
//
// It returns the message type together with the full payload. If obtaining
// the reader fails, the returned type is 0 and the payload is nil. An error
// hit while draining the reader is returned alongside whatever bytes were
// read before it.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	msgType, msgReader, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	payload, err := ioutil.ReadAll(msgReader)
	return msgType, payload, err
}
// CloseRead spawns a goroutine that keeps reading from the connection until
// the connection is closed or the peer sends a data message.
//
// After CloseRead has been called, no further messages can be read from the
// connection. The context it returns is cancelled once the connection is closed.
//
// Receiving a data message closes the connection with StatusPolicyViolation.
//
// Use CloseRead when no more messages are expected from the peer. Because the
// goroutine actively reads from the connection, ping, pong and close frames
// continue to be responded to.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	readCtx, cancel := context.WithCancel(ctx)

	go func() {
		// Cancelling on exit is what signals callers holding readCtx.
		defer cancel()

		// The results are deliberately discarded: whether Reader returned a
		// data message or an error, the connection is closed below.
		_, _, _ = c.Reader(readCtx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()

	return readCtx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
func newMsgReader(c *Conn) *msgReader {
mr := &msgReader{
c: c,
fin: true,
}
mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768)
mr.initFlateReader()
}
return mr
}
func (mr *msgReader) initFlateReader() {
mr.flateReader = getFlateReader(readerFunc(mr.read))
mr.limitReader.r = mr.flateReader
mr.c.readMu.Lock(context.Background())
defer mr.c.readMu.Unlock()
return true
}
// rsv1 is only allowed on data frames beginning messages.
if h.opcode != opText && h.opcode != opBinary {
return true
}
return false
}
func (c *Conn) readLoop(ctx context.Context) (header, error) {
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
case opContinuation, opText, opBinary:
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- ctx:
case <-c.closed:
return header{}, c.closeErr
case <-ctx.Done():
return header{}, ctx.Err()
default:
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- context.Background():
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
case <-c.closed:
return 0, c.closeErr
case c.readTimeout <- ctx:
case <-ctx.Done():
return n, ctx.Err()
default:
err = fmt.Errorf("failed to read frame payload: %w", err)
case <-c.closed:
return n, c.closeErr
case c.readTimeout <- context.Background():
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
if err != nil {
return err
}
if h.masked {
mask(h.maskKey, b)
}
switch h.opcode {
case opPing:
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.Lock(ctx)
return 0, nil, errors.New("previous message not read to completion")
}
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
ctx context.Context
deflate bool
flateReader io.Reader
deflateTail strings.Reader
limitReader *limitReader
payloadLength int64
maskKey uint32
}
func (mr *msgReader) reset(ctx context.Context, h header) {
mr.ctx = ctx
mr.deflate = h.rsv1
if mr.deflate {
mr.setFrame(h)
}
// setFrame copies the payload length and mask key from frame header h into
// the reader, making h the frame whose payload is consumed next.
func (mr *msgReader) setFrame(h header) {
	mr.payloadLength, mr.maskKey = h.payloadLength, h.maskKey
}
func (mr *msgReader) Read(p []byte) (n int, err error) {
r := recover()
if r != nil {
if r != "ANMOL" {
panic(r)
}
err = io.EOF
if !mr.flateContextTakeover() {
mr.returnFlateReader()
}
}
errd.Wrap(&err, "failed to read")
if errors.Is(err, io.EOF) {
err = io.EOF
}
}()
// returnFlateReader hands the current flate reader back via putFlateReader
// and clears the field. It does nothing when no flate reader is held, so it
// is safe to call more than once.
func (mr *msgReader) returnFlateReader() {
	if mr.flateReader == nil {
		return
	}

	putFlateReader(mr.flateReader)
	mr.flateReader = nil
}
func (mr *msgReader) read(p []byte) (int, error) {
if mr.payloadLength == 0 {
if mr.fin {
if mr.deflate {
if mr.deflateTail.Len() == 0 {
panic("ANMOL")
}
n, _ := mr.deflateTail.Read(p)
return n, nil
}
return 0, io.EOF
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
return 0, err
}
mr.setFrame(h)
}
if int64(len(p)) > mr.payloadLength {
p = p[:mr.payloadLength]
}
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
mr.maskKey = mask(mr.maskKey, p)
}
return n, nil
}
// limitReader wraps an io.Reader and enforces a byte budget on what may be
// read through it.
type limitReader struct {
// c is the connection this reader was created for.
c *Conn
// r is the underlying reader that Read delegates to. initFlateReader
// reassigns it to the flate reader.
r io.Reader

// limit is the configured maximum, stored at construction and reported in
// the error produced when the budget is exhausted.
// NOTE(review): presumably atomic so SetReadLimit can change it
// concurrently with reads — confirm against SetReadLimit.
limit atomicInt64
// n is the remaining budget in bytes: Read fails once n <= 0, truncates
// the caller's buffer to n, and subtracts the bytes actually read.
n int64
}
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
lr := &limitReader{
c: c,
}
lr.limit.Store(limit)
func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n <= 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
return 0, err
}
if int64(len(p)) > lr.n {
p = p[:lr.n]
}
n, err := lr.r.Read(p)
lr.n -= int64(n)
return n, err
}
// atomicInt64 is an int64 that can be loaded and stored atomically.
//
// The value is kept in an atomic.Value, so the zero atomicInt64 is ready to
// use and reads as 0.
type atomicInt64 struct {
	i atomic.Value
}

// Load returns the most recently stored value, or 0 if Store has never been
// called.
func (v *atomicInt64) Load() int64 {
	if stored, ok := v.i.Load().(int64); ok {
		return stored
	}
	// Nothing stored yet: atomic.Value.Load returned nil.
	return 0
}

// Store atomically replaces the held value with i.
func (v *atomicInt64) Store(i int64) {
	v.i.Store(i)
}
// readerFunc adapts an ordinary function with the Read signature so that it
// can be used wherever a value with a Read method is required.
type readerFunc func(p []byte) (int, error)

// Read calls f(p) and returns its results unchanged.
func (f readerFunc) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}