Newer
Older
package websocket
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"strings"
"sync/atomic"
"time"
// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
// Read is a convenience wrapper over Reader: it obtains the reader for the
// next message and buffers that message's entire payload in memory.
//
// It returns the message type together with the bytes read. If Reader fails,
// the zero message type and a nil slice are returned with that error. If
// draining the reader fails, the bytes read so far are returned alongside
// the error.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	msgType, reader, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	payload, readErr := ioutil.ReadAll(reader)
	return msgType, payload, readErr
}
// CloseRead spawns a goroutine that reads from the connection until it is
// closed or a data message arrives.
//
// After calling CloseRead, no messages can be read from the connection by
// the caller. The context it returns is cancelled once the background read
// finishes, which happens when the connection is closed.
//
// Receiving a data message closes the connection with StatusPolicyViolation.
//
// Use CloseRead when no further messages are expected from the peer. Because
// the goroutine keeps reading from the connection, ping, pong and close
// frames continue to be responded to.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	readCtx, stop := context.WithCancel(ctx)

	// drain blocks in Reader; whatever Reader returns, the connection is
	// then closed and the derived context is cancelled on the way out.
	drain := func() {
		defer stop()
		c.Reader(readCtx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}
	go drain()

	return readCtx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
func (mr *msgReader) ensureFlateReader() {
mr.flateReader = getFlateReader(readerFunc(mr.read))
mr.limitReader.reset(mr.flateReader)
func (mr *msgReader) close() {
if mr.c.deflateNegotiated() && mr.contextTakeover() {
mr.c.readMu.Lock(context.Background())
putFlateReader(mr.flateReader)
mr.c.readMu.Unlock()
func (mr *msgReader) contextTakeover() bool {
if mr.c.client {
return mr.c.copts.serverNoContextTakeover
return true
}
// rsv1 is only allowed on data frames beginning messages.
if h.opcode != opText && h.opcode != opBinary {
return true
}
return false
}
func (c *Conn) readLoop(ctx context.Context) (header, error) {
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
case opContinuation, opText, opBinary:
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- ctx:
case <-c.closed:
return header{}, c.closeErr
case <-ctx.Done():
return header{}, ctx.Err()
default:
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- context.Background():
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
case <-c.closed:
return 0, c.closeErr
case c.readTimeout <- ctx:
case <-ctx.Done():
return n, ctx.Err()
default:
err = fmt.Errorf("failed to read frame payload: %w", err)
case <-c.closed:
return n, c.closeErr
case c.readTimeout <- context.Background():
func (c *Conn) handleControl(ctx context.Context, h header) error {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := c.readControlBuf[:h.payloadLength]
_, err := c.readFramePayload(ctx, b)
if err != nil {
return err
}
if h.masked {
mask(h.maskKey, b)
}
switch h.opcode {
case opPing:
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
close(pong)
}
return nil
}
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
c.setCloseErr(err)
c.writeControl(context.Background(), opClose, ce.bytes())
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.Lock(ctx)
return 0, nil, errors.New("previous message not read to completion")
}
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
payloadLength int64
maskKey uint32
fin bool
}
// reset prepares mr to read a new message whose first frame header is h.
// ctx is stored on mr and used by later reads of this message.
func (mr *msgReader) reset(ctx context.Context, h header) {
mr.ctx = ctx
// The deflate flag for the message is taken from the rsv1 bit of its first frame.
mr.deflate = h.rsv1
if mr.deflate {
// Rewind the tail reader so deflateMessageTail can be served again for this message.
mr.deflateTail.Reset(deflateMessageTail)
// Without context takeover a flate reader is set up per message.
if !mr.contextTakeover() {
mr.ensureFlateReader()
}
}
mr.setFrame(h)
// NOTE(review): setFrame has just copied h.fin into mr.fin; this line
// overrides it, so fin is always false after reset. Confirm this is intended.
mr.fin = false
}
// setFrame copies the per-frame state from header h onto mr: the payload
// length, the mask key, and the fin flag.
func (mr *msgReader) setFrame(h header) {
	mr.payloadLength, mr.maskKey, mr.fin = h.payloadLength, h.maskKey, h.fin
}
func (mr *msgReader) Read(p []byte) (_ int, err error) {
defer func() {
errd.Wrap(&err, "failed to read")
if errors.Is(err, io.EOF) {
err = io.EOF
}
}()
if mr.c.deflateNegotiated() && !mr.contextTakeover() {
if mr.flateReader != nil {
putFlateReader(mr.flateReader)
mr.flateReader = nil
}
func (mr *msgReader) read(p []byte) (int, error) {
if mr.payloadLength == 0 {
if mr.fin {
if mr.deflate {
n, _ := mr.deflateTail.Read(p[:4])
return n, nil
}
return 0, io.EOF
}
h, err := mr.c.readLoop(mr.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
return 0, err
}
mr.setFrame(h)
}
if int64(len(p)) > mr.payloadLength {
p = p[:mr.payloadLength]
}
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
mr.maskKey = mask(mr.maskKey, p)
}
return n, nil
}
// limitReader wraps an io.Reader and caps the number of bytes that may be
// read from it; once the budget is used up, Read returns an error.
type limitReader struct {
// c is the connection this reader belongs to.
// NOTE(review): not referenced by the limitReader methods visible here — confirm its use.
c *Conn
// r is the underlying reader that Read delegates to.
r io.Reader
// limit is the configured byte cap, stored atomically; reset copies it into n.
limit atomicInt64
// n is the remaining byte budget; Read subtracts the bytes returned by each read.
n int64
}
// newLimitReader builds a limitReader for connection c that reads from r and
// allows at most limit bytes. The limit is stored first so that the reset
// call starts the remaining budget at limit.
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	reader := new(limitReader)
	reader.c = c
	reader.limit.Store(limit)
	reader.reset(r)
	return reader
}
// reset points the limitReader at r and refills the remaining byte budget
// from the currently configured limit.
func (lr *limitReader) reset(r io.Reader) {
	lr.r, lr.n = r, lr.limit.Load()
}
// setLimit atomically stores a new byte limit. The remaining budget n is not
// changed here; it is only reloaded from the limit by reset.
func (lr *limitReader) setLimit(limit int64) {
lr.limit.Store(limit)
}
// Read implements io.Reader on top of the wrapped reader while enforcing the
// remaining byte budget.
//
// If the budget is already exhausted, it returns 0 and an error naming the
// configured limit. Otherwise p is truncated to the remaining budget, the
// read is delegated to the wrapped reader, and the budget is reduced by the
// number of bytes that reader returned. The wrapped reader's error, if any,
// is passed through unchanged.
func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n <= 0 {
		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
		// NOTE(review): SetReadLimit's documentation says the connection is
		// closed with StatusMessageTooBig when the limit is hit, but nothing
		// here touches lr.c — confirm whether a call was lost from this branch.
		return 0, err
	}

	// Never hand the wrapped reader a buffer larger than what is still allowed.
	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}

	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	return n, err
}
// atomicInt64 is an int64 that can be loaded and stored atomically. It is
// backed by an atomic.Value, and its zero value reads as 0.
type atomicInt64 struct {
	i atomic.Value
}

// Load returns the most recently stored value, or 0 when nothing has been
// stored yet.
func (v *atomicInt64) Load() int64 {
	if n, ok := v.i.Load().(int64); ok {
		return n
	}
	return 0
}

// Store atomically replaces the held value with i.
func (v *atomicInt64) Store(i int64) {
	v.i.Store(i)
}
// readerFunc adapts a plain function with the Read signature so that it can
// be used wherever a value with a Read method is expected.
type readerFunc func(p []byte) (int, error)

// Read calls f with p and returns its results unchanged.
func (f readerFunc) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}