good morning!!!!

Skip to content
Snippets Groups Projects
websocket.go 19.8 KiB
Newer Older
package websocket
Anmol Sethi's avatar
Anmol Sethi committed

import (
Anmol Sethi's avatar
Anmol Sethi committed
	"bufio"
Anmol Sethi's avatar
Anmol Sethi committed
	"context"
Anmol Sethi's avatar
Anmol Sethi committed
	"fmt"
	"io"
Anmol Sethi's avatar
Anmol Sethi committed
	"io/ioutil"
	"math/rand"
Anmol Sethi's avatar
Anmol Sethi committed
	"os"
Anmol Sethi's avatar
Anmol Sethi committed
	"runtime"
Anmol Sethi's avatar
Anmol Sethi committed
	"strconv"
Anmol Sethi's avatar
Anmol Sethi committed
	"sync"
	"time"

	"golang.org/x/xerrors"
Anmol Sethi's avatar
Anmol Sethi committed
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader, Read
// and SetReadLimit.
Anmol Sethi's avatar
Anmol Sethi committed
//
// Please be sure to call Close on the connection when you
Anmol Sethi's avatar
Anmol Sethi committed
// are finished with it to release the associated resources.
Anmol Sethi's avatar
Anmol Sethi committed
type Conn struct {
	subprotocol string
	br          *bufio.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	bw          *bufio.Writer
	closer      io.Closer
	client      bool
Anmol Sethi's avatar
Anmol Sethi committed

	// read limit for a message in bytes.
Anmol Sethi's avatar
Anmol Sethi committed
	msgReadLimit int64

Anmol Sethi's avatar
Anmol Sethi committed
	closeOnce sync.Once
	closeErr  error
	closed    chan struct{}

	// writeMsgLock is acquired to write a data message.
	writeMsgLock chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed
	// writeFrameLock is acquired to write a single frame.
	// Effectively meaning whoever holds it gets to write to bw.
	// Used to ensure the previous reader is read till EOF before allowing
	// a new one.
	previousReader *messageReader
Anmol Sethi's avatar
Anmol Sethi committed
	// readFrameLock is acquired to read from bw.
Anmol Sethi's avatar
Anmol Sethi committed
	readFrameLock chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed
	// readMsg is used by messageReader to receive frames from
	// readLoop.
	readMsg chan header
Anmol Sethi's avatar
Anmol Sethi committed
	// readMsgDone is used to tell the readLoop to continue after
	// messageReader has read a frame.
	readMsgDone chan struct{}

	setReadTimeout  chan context.Context
	setWriteTimeout chan context.Context
	setConnContext  chan context.Context
	getConnContext  chan context.Context
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) init() {
	c.closed = make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
	c.msgReadLimit = 32768

Anmol Sethi's avatar
Anmol Sethi committed
	c.writeMsgLock = make(chan struct{}, 1)
	c.writeFrameLock = make(chan struct{}, 1)

	c.readFrameLock = make(chan struct{}, 1)
Anmol Sethi's avatar
Anmol Sethi committed
	c.readMsg = make(chan header)
	c.readMsgDone = make(chan struct{})
	c.setReadTimeout = make(chan context.Context)
	c.setWriteTimeout = make(chan context.Context)
	c.setConnContext = make(chan context.Context)
	c.getConnContext = make(chan context.Context)
Anmol Sethi's avatar
Anmol Sethi committed
	c.activePings = make(map[string]chan<- struct{})
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	runtime.SetFinalizer(c, func(c *Conn) {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.New("connection garbage collected"))
Anmol Sethi's avatar
Anmol Sethi committed
	go c.readLoop()
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) close(err error) {
	c.closeOnce.Do(func() {
		runtime.SetFinalizer(c, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		c.closeErr = xerrors.Errorf("websocket closed: %w", err)
		close(c.closed)
Anmol Sethi's avatar
Anmol Sethi committed
		// Have to close after c.closed is closed to ensure any goroutine that wakes up
		// from the connection being closed also sees that c.closed is closed and returns
		// closeErr.
		c.closer.Close()
Anmol Sethi's avatar
Anmol Sethi committed
		// See comment in dial.go
		if c.client {
			// By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer
			// and we can safely return them.
			// Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent
			// a deadlock.
			// As of now, this is in writeFrame, readFramePayload and readHeader.
			c.readFrameLock <- struct{}{}
			returnBufioReader(c.br)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeFrameLock <- struct{}{}
			returnBufioWriter(c.bw)
Anmol Sethi's avatar
Anmol Sethi committed
	})
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()
	parentCtx := context.Background()
Anmol Sethi's avatar
Anmol Sethi committed
	for {
		select {
		case <-c.closed:
			return
Anmol Sethi's avatar
Anmol Sethi committed
		case writeCtx = <-c.setWriteTimeout:
		case readCtx = <-c.setReadTimeout:
		case <-readCtx.Done():
			c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err()))
		case <-writeCtx.Done():
			c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err()))
		case <-parentCtx.Done():
			c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err()))
			return
		case parentCtx = <-c.setConnContext:
Anmol Sethi's avatar
Anmol Sethi committed
			ctx, cancelCtx := context.WithCancel(parentCtx)
			defer cancelCtx()

Anmol Sethi's avatar
Anmol Sethi committed
			select {
			case <-c.closed:
				return
			case c.getConnContext <- ctx:
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
// Context returns a context derived from parent that will be cancelled
// when the connection is closed or broken.
// If the parent context is cancelled, the connection will be closed.
func (c *Conn) Context(parent context.Context) context.Context {
	select {
	case <-c.closed:
		ctx, cancel := context.WithCancel(parent)
		cancel()
		return ctx
	case c.setConnContext <- parent:
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		ctx, cancel := context.WithCancel(parent)
		cancel()
		return ctx
	case ctx := <-c.getConnContext:
		return ctx
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
	select {
	case <-ctx.Done():
		var err error
		switch lock {
		case c.writeFrameLock, c.writeMsgLock:
			err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err())
		case c.readFrameLock:
			err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err())
		default:
			panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err()))
		}
		c.close(err)
		return ctx.Err()
	case <-c.closed:
		return c.closeErr
	case lock <- struct{}{}:
		return nil
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) releaseLock(lock chan struct{}) {
	// Allow multiple releases.
	select {
	case <-lock:
	default:
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readLoop() {
	for {
		h, err := c.readTillMsg()
		if err != nil {
			return
		}
Anmol Sethi's avatar
Anmol Sethi committed

		select {
		case <-c.closed:
			return
		case c.readMsg <- h:
		}

		select {
		case <-c.closed:
			return
		case <-c.readMsgDone:
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readTillMsg() (header, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	for {
Anmol Sethi's avatar
Anmol Sethi committed
		h, err := c.readFrameHeader()
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return header{}, err
Anmol Sethi's avatar
Anmol Sethi committed
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if h.rsv1 || h.rsv2 || h.rsv3 {
Anmol Sethi's avatar
Anmol Sethi committed
			err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			c.Close(StatusProtocolError, err.Error())
			return header{}, err
Anmol Sethi's avatar
Anmol Sethi committed
		}

		if h.opcode.controlOp() {
Anmol Sethi's avatar
Anmol Sethi committed
			c.handleControl(h)
			continue
		}

Anmol Sethi's avatar
Anmol Sethi committed
		switch h.opcode {
		case opBinary, opText, opContinuation:
			return h, nil
Anmol Sethi's avatar
Anmol Sethi committed
		default:
Anmol Sethi's avatar
Anmol Sethi committed
			err := xerrors.Errorf("received unknown opcode %v", h.opcode)
			c.Close(StatusProtocolError, err.Error())
			return header{}, err
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFrameHeader() (header, error) {
	err := c.acquireLock(context.Background(), c.readFrameLock)
	if err != nil {
		return header{}, err
	}
	defer c.releaseLock(c.readFrameLock)
	h, err := readHeader(c.br)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		err := xerrors.Errorf("failed to read header: %w", err)
		c.releaseLock(c.readFrameLock)
		c.close(err)
		return header{}, err
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) handleControl(h header) {
	if h.payloadLength > maxControlFramePayload {
		c.Close(StatusProtocolError, "control frame too large")
		return
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	if !h.fin {
		c.Close(StatusProtocolError, "control frame cannot be fragmented")
		return
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

Anmol Sethi's avatar
Anmol Sethi committed
	b := make([]byte, h.payloadLength)

	_, err := c.readFramePayload(ctx, b)
	if err != nil {
		return
	}

	if h.masked {
		fastXOR(h.maskKey, 0, b)
	}

	switch h.opcode {
	case opPing:
		c.writePong(b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			close(pong)
		}
	case opClose:
		ce, err := parseClosePayload(b)
		if err != nil {
			c.close(xerrors.Errorf("received invalid close payload: %w", err))
			return
		}
		if ce.Code == StatusNoStatusRcvd {
			c.writeClose(nil, ce)
		} else {
			c.Close(ce.Code, ce.Reason)
		}
	default:
		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Reader waits until there is a WebSocket data message to read
// from the connection.
// It returns the type of the message and a reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
Anmol Sethi's avatar
Anmol Sethi committed
// Control (ping, pong, close) frames will be handled automatically
// in a separate goroutine so if you do not expect any data messages,
// you do not need  to read from the connection. However, if the peer
// sends a data message, further pings, pongs and close frames will not
// be read if you do not read the message from the connection.
Anmol Sethi's avatar
Anmol Sethi committed
// Only one Reader may be open at a time.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	typ, r, err := c.reader(ctx)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	return typ, &limitedReader{
		c:    c,
		r:    r,
		left: c.msgReadLimit,
	}, nil
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
	if c.previousReader != nil && c.previousReader.h != nil {
		// The only way we know for sure the previous reader is not yet complete is
		// if there is an active frame not yet fully read.
		// Otherwise, a user may have read the last byte but not the EOF if the EOF
		// is in the next frame so we check for that below.
		return 0, nil, xerrors.Errorf("previous message not read to completion")
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		return 0, nil, c.closeErr
	case <-ctx.Done():
		return 0, nil, ctx.Err()
	case h := <-c.readMsg:
		if c.previousReader != nil && !c.previousReader.done {
			if h.opcode != opContinuation {
				err := xerrors.Errorf("received new data message without finishing the previous message")
				c.Close(StatusProtocolError, err.Error())
				return 0, nil, err
			}

			if !h.fin || h.payloadLength > 0 {
				return 0, nil, xerrors.Errorf("previous message not read to completion")
			}

			c.previousReader.done = true

			select {
			case <-c.closed:
				return 0, nil, c.closeErr
			case c.readMsgDone <- struct{}{}:
			}

			return c.reader(ctx)
		} else if h.opcode == opContinuation {
			err := xerrors.Errorf("received continuation frame not after data or text frame")
			c.Close(StatusProtocolError, err.Error())
			return 0, nil, err
Anmol Sethi's avatar
Anmol Sethi committed

		r := &messageReader{
			ctx: ctx,
			c:   c,

			h: &h,
		}
		c.previousReader = r
		return MessageType(h.opcode), r, nil
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
	ctx context.Context
	c   *Conn

	h       *header
	maskPos int
	done    bool
Anmol Sethi's avatar
Anmol Sethi committed
// Read reads up to len(p) bytes of the message into p.
//
// io.EOF is returned unwrapped at the end of the message because
// standard-library consumers compare against it with ==; every other
// error is wrapped with context.
func (r *messageReader) Read(p []byte) (int, error) {
	n, err := r.read(p)
	switch {
	case err == nil:
		return n, nil
	case xerrors.Is(err, io.EOF):
		return n, io.EOF
	default:
		return n, xerrors.Errorf("failed to read: %w", err)
	}
}

func (r *messageReader) read(p []byte) (int, error) {
	if r.done {
		return 0, xerrors.Errorf("cannot use EOFed reader")
	}

	if r.h == nil {
		select {
		case <-r.c.closed:
			return 0, r.c.closeErr
		case <-r.ctx.Done():
			r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err()))
			return 0, r.ctx.Err()
		case h := <-r.c.readMsg:
			if h.opcode != opContinuation {
				err := xerrors.Errorf("received new data frame without finishing the previous frame")
				r.c.Close(StatusProtocolError, err.Error())
				return 0, err
			}
			r.h = &h
		}
	}

	if int64(len(p)) > r.h.payloadLength {
		p = p[:r.h.payloadLength]
	}

	n, err := r.c.readFramePayload(r.ctx, p)

	r.h.payloadLength -= int64(n)
	if r.h.masked {
		r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p)
	}
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return n, err
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	if r.h.payloadLength == 0 {
		select {
		case <-r.c.closed:
			return n, r.c.closeErr
		case r.c.readMsgDone <- struct{}{}:
		}

		fin := r.h.fin

		// Need to nil this as Reader uses it to check
		// whether there is active data on the previous reader and
		// now there isn't.
		r.h = nil

		if fin {
			r.done = true
			return n, io.EOF
		}

		r.maskPos = 0
Anmol Sethi's avatar
Anmol Sethi committed
	return n, nil
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	err := c.acquireLock(ctx, c.readFrameLock)
	if err != nil {
		return 0, err
	}
	defer c.releaseLock(c.readFrameLock)

Anmol Sethi's avatar
Anmol Sethi committed
	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return 0, c.closeErr
	case c.setReadTimeout <- ctx:
	}

	n, err := io.ReadFull(c.br, p)
	if err != nil {
		select {
		case <-c.closed:
			return n, c.closeErr
		case <-ctx.Done():
			err = ctx.Err()
Anmol Sethi's avatar
Anmol Sethi committed
		err = xerrors.Errorf("failed to read from connection: %w", err)
		c.releaseLock(c.readFrameLock)
		c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
		return n, err
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
		return n, c.closeErr
	case c.setReadTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

	return n, err
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusPolicyViolation.
//
// Reader copies the current limit into each reader it returns, so a new
// limit applies from the next message onwards.
func (c *Conn) SetReadLimit(n int64) {
	limit := n
	c.msgReadLimit = limit
}

// Read is a convenience method to read a single message from the connection.
//
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
// The docs on Reader apply to this method as well.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
Anmol Sethi's avatar
Anmol Sethi committed

	b, err := ioutil.ReadAll(r)
	return typ, b, err
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
Anmol Sethi's avatar
Anmol Sethi committed
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	wc, err := c.writer(ctx, typ)
	if err != nil {
		return nil, xerrors.Errorf("failed to get writer: %w", err)
	}
	return wc, nil
}

func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	err := c.acquireLock(ctx, c.writeMsgLock)
Anmol Sethi's avatar
Anmol Sethi committed
	return &messageWriter{
		ctx:    ctx,
		opcode: opcode(typ),
		c:      c,
	}, nil
Anmol Sethi's avatar
Anmol Sethi committed
}

// Write is a convenience method to write a message to the connection.
//
Anmol Sethi's avatar
Anmol Sethi committed
// See the Writer method if you want to stream a message. The docs on Writer
// regarding concurrency also apply to this method.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
Anmol Sethi's avatar
Anmol Sethi committed
	err := c.write(ctx, typ, p)
	if err != nil {
		return xerrors.Errorf("failed to write msg: %w", err)
	}
	return nil
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
	err := c.acquireLock(ctx, c.writeMsgLock)
	if err != nil {
		return err
	}
	defer c.releaseLock(c.writeMsgLock)

	err = c.writeFrame(ctx, true, opcode(typ), p)
	return err
Anmol Sethi's avatar
Anmol Sethi committed
}

// messageWriter enables writing to a WebSocket connection.
type messageWriter struct {
	ctx    context.Context
	opcode opcode
	c      *Conn
	closed bool
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

// Write writes the given bytes to the WebSocket connection.
func (w *messageWriter) Write(p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	n, err := w.write(p)
	if err != nil {
		return n, xerrors.Errorf("failed to write: %w", err)
	}
	return n, nil
}

func (w *messageWriter) write(p []byte) (int, error) {
	if w.closed {
		return 0, xerrors.Errorf("cannot use closed writer")
	}
Anmol Sethi's avatar
Anmol Sethi committed
	err := w.c.writeFrame(w.ctx, false, w.opcode, p)
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, xerrors.Errorf("failed to write data frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	w.opcode = opContinuation
	return len(p), nil
Anmol Sethi's avatar
Anmol Sethi committed
}

// Close ends the message by writing its final frame, flushes it to the
// connection, and gives up the connection's message write lock.
// Every messageWriter must be closed.
func (w *messageWriter) Close() error {
	if err := w.close(); err != nil {
		return xerrors.Errorf("failed to close writer: %w", err)
	}
	return nil
}

func (w *messageWriter) close() error {
	if w.closed {
		return xerrors.Errorf("cannot use closed writer")
Anmol Sethi's avatar
Anmol Sethi committed
	err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		return xerrors.Errorf("failed to write fin frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	w.c.releaseLock(w.c.writeMsgLock)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	err := c.writeFrame(ctx, true, opcode, p)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return xerrors.Errorf("failed to write control frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	return nil
Anmol Sethi's avatar
Anmol Sethi committed
// writeFrame handles all writes to the connection.
// We never mask inside here because our mask key is always 0,0,0,0.
// See comment on secWebSocketKey for why.
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error {
	h := header{
		fin:           fin,
		opcode:        opcode,
		masked:        c.client,
		payloadLength: int64(len(p)),
	}
	b2 := marshalHeader(h)

	err := c.acquireLock(ctx, c.writeFrameLock)
	if err != nil {
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	defer c.releaseLock(c.writeFrameLock)
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
		return c.closeErr
	case c.setWriteTimeout <- ctx:
	}
Anmol Sethi's avatar
Anmol Sethi committed
	writeErr := func(err error) error {
		select {
		case <-c.closed:
			return c.closeErr
		case <-ctx.Done():
			err = ctx.Err()
		default:
Anmol Sethi's avatar
Anmol Sethi committed
		err = xerrors.Errorf("failed to write to connection: %w", err)
		// We need to release the lock first before closing the connection to ensure
		// the lock can be acquired inside close to ensure no one can access c.bw.
		c.releaseLock(c.writeFrameLock)
		c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
		return err
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	_, err = c.bw.Write(b2)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return writeErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
	_, err = c.bw.Write(p)
	if err != nil {
		return writeErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
	if fin {
		err = c.bw.Flush()
		if err != nil {
			return writeErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
	// We already finished writing, no need to potentially brick the connection if
	// the context expires.
	select {
	case <-c.closed:
		return c.closeErr
	case c.setWriteTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	return nil
}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writePong(p []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.writeControl(ctx, opPong, p)
	return err
}
Anmol Sethi's avatar
Anmol Sethi committed
// Close closes the WebSocket connection with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5 seconds.
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes otherwise an internal
// error will be sent to the peer. For this reason, you should avoid
// sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection.
func (c *Conn) Close(code StatusCode, reason string) error {
	err := c.exportedClose(code, reason)
Anmol Sethi's avatar
Anmol Sethi committed
		return xerrors.Errorf("failed to close connection: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	return nil
}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) exportedClose(code StatusCode, reason string) error {
	ce := CloseError{
		Code:   code,
		Reason: reason,
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	// This function also will not wait for a close frame from the peer like the RFC
	// wants because that makes no sense and I don't think anyone actually follows that.
	// Definitely worth seeing what popular browsers do later.
	p, err := ce.bytes()
	if err != nil {
		fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err)
		ce = CloseError{
			Code: StatusInternalError,
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed
		p, _ = ce.bytes()
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	return c.writeClose(p, ce)
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeClose(p []byte, cerr CloseError) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	err := c.writeControl(ctx, opClose, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	c.close(cerr)
	if !xerrors.Is(c.closeErr, cerr) {
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	return nil
Anmol Sethi's avatar
Anmol Sethi committed
}

func init() {
	// ping draws its payload IDs from the global math/rand source, so seed
	// it once per process.
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}

// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// TCP Keepalives should suffice for most use cases.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Ping(ctx context.Context) error {
	err := c.ping(ctx)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return xerrors.Errorf("failed to ping: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return nil
}

func (c *Conn) ping(ctx context.Context) error {
	id := rand.Uint64()
	p := strconv.FormatUint(id, 10)

	pong := make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
	c.activePingsMu.Lock()
	c.activePings[p] = pong
	c.activePingsMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed

	defer func() {
Anmol Sethi's avatar
Anmol Sethi committed
		c.activePingsMu.Lock()
		delete(c.activePings, p)
		c.activePingsMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed
	}()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	err := c.writeControl(ctx, opPing, []byte(p))
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return err
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	case <-ctx.Done():
Anmol Sethi's avatar
Anmol Sethi committed
		err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(err)
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	case <-pong:
		return nil
	}
}