good morning!!!!

Skip to content
Snippets Groups Projects
conn.go 21.9 KiB
Newer Older
// +build !js

package websocket
Anmol Sethi's avatar
Anmol Sethi committed

import (
Anmol Sethi's avatar
Anmol Sethi committed
	"bufio"
Anmol Sethi's avatar
Anmol Sethi committed
	"context"
Anmol Sethi's avatar
Anmol Sethi committed
	"errors"
Anmol Sethi's avatar
Anmol Sethi committed
	"fmt"
	"io"
Anmol Sethi's avatar
Anmol Sethi committed
	"io/ioutil"
Anmol Sethi's avatar
Anmol Sethi committed
	"math/rand"
Anmol Sethi's avatar
Anmol Sethi committed
	"runtime"
Anmol Sethi's avatar
Anmol Sethi committed
	"strconv"
Anmol Sethi's avatar
Anmol Sethi committed
	"sync"
	"sync/atomic"
Anmol Sethi's avatar
Anmol Sethi committed
	"time"
Anmol Sethi's avatar
Anmol Sethi committed
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
Anmol Sethi's avatar
Anmol Sethi committed
//
Anmol Sethi's avatar
Anmol Sethi committed
// You must always read from the connection. Otherwise control
Anmol Sethi's avatar
Anmol Sethi committed
// frames will not be handled. See the docs on Reader and CloseRead.
Anmol Sethi's avatar
Anmol Sethi committed
//
Anmol Sethi's avatar
Anmol Sethi committed
// Be sure to call Close on the connection when you
Anmol Sethi's avatar
Anmol Sethi committed
// are finished with it to release the associated resources.
//
// Every error from Read or Reader will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
Anmol Sethi's avatar
Anmol Sethi committed
type Conn struct {
	subprotocol string
	br          *bufio.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	bw          *bufio.Writer
	// writeBuf is used for masking, its the buffer in bufio.Writer.
Anmol Sethi's avatar
Anmol Sethi committed
	// Only used by the client for masking the bytes in the buffer.
	writeBuf []byte
	closer   io.Closer
	client   bool
Anmol Sethi's avatar
Anmol Sethi committed

	closeOnce    sync.Once
	closeErrOnce sync.Once
	closeErr     error
	closed       chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed

	// messageWriter state.
	// writeMsgLock is acquired to write a data message.
	writeMsgLock chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed
	// writeFrameLock is acquired to write a single frame.
	// Effectively meaning whoever holds it gets to write to bw.
	writeFrameLock chan struct{}
	writeHeaderBuf []byte
	writeHeader    *header
	// read limit for a message in bytes.
	msgReadLimit *atomicInt64
	// Used to ensure a previous writer is not used after being closed.
	activeWriter atomic.Value
	// messageWriter state.
	writeMsgOpcode opcode
	writeMsgCtx    context.Context
	readMsgLeft    int64
	// Used to ensure the previous reader is read till EOF before allowing
	// a new one.
	activeReader *messageReader
Anmol Sethi's avatar
Anmol Sethi committed
	// readFrameLock is acquired to read from bw.
	readFrameLock     chan struct{}
	isReadClosed      *atomicInt64
	readHeaderBuf     []byte
	controlPayloadBuf []byte
	// messageReader state.
	readerMsgCtx    context.Context
	readerMsgHeader header
	readerFrameEOF  bool
	readerMaskPos   int
	setReadTimeout  chan context.Context
	setWriteTimeout chan context.Context
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) init() {
	c.closed = make(chan struct{})
	c.msgReadLimit = &atomicInt64{}
	c.msgReadLimit.Store(32768)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	c.writeMsgLock = make(chan struct{}, 1)
	c.writeFrameLock = make(chan struct{}, 1)

	c.readFrameLock = make(chan struct{}, 1)
	c.setReadTimeout = make(chan context.Context)
	c.setWriteTimeout = make(chan context.Context)
Anmol Sethi's avatar
Anmol Sethi committed
	c.activePings = make(map[string]chan<- struct{})
Anmol Sethi's avatar
Anmol Sethi committed

	c.writeHeaderBuf = makeWriteHeaderBuf()
	c.writeHeader = &header{}
	c.readHeaderBuf = makeReadHeaderBuf()
	c.isReadClosed = &atomicInt64{}
	c.controlPayloadBuf = make([]byte, maxControlFramePayload)

Anmol Sethi's avatar
Anmol Sethi committed
	runtime.SetFinalizer(c, func(c *Conn) {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(errors.New("connection garbage collected"))
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) close(err error) {
	c.closeOnce.Do(func() {
		runtime.SetFinalizer(c, nil)
		c.setCloseErr(err)
Anmol Sethi's avatar
Anmol Sethi committed
		close(c.closed)
Anmol Sethi's avatar
Anmol Sethi committed
		// Have to close after c.closed is closed to ensure any goroutine that wakes up
		// from the connection being closed also sees that c.closed is closed and returns
		// closeErr.
		c.closer.Close()
Anmol Sethi's avatar
Anmol Sethi committed
		// See comment on bufioReaderPool in handshake.go
Anmol Sethi's avatar
Anmol Sethi committed
		if c.client {
			// By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer
			// and we can safely return them.
			// Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent
			// a deadlock.
			// As of now, this is in writeFrame, readFramePayload and readHeader.
			c.readFrameLock <- struct{}{}
			returnBufioReader(c.br)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeFrameLock <- struct{}{}
			returnBufioWriter(c.bw)
Anmol Sethi's avatar
Anmol Sethi committed
	})
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()
Anmol Sethi's avatar
Anmol Sethi committed
	for {
		select {
		case <-c.closed:
			return
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		case writeCtx = <-c.setWriteTimeout:
		case readCtx = <-c.setReadTimeout:
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
			c.close(fmt.Errorf("read timed out: %w", readCtx.Err()))
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
	select {
	case <-ctx.Done():
		var err error
		switch lock {
		case c.writeFrameLock, c.writeMsgLock:
Anmol Sethi's avatar
Anmol Sethi committed
			err = fmt.Errorf("could not acquire write lock: %v", ctx.Err())
Anmol Sethi's avatar
Anmol Sethi committed
		case c.readFrameLock:
Anmol Sethi's avatar
Anmol Sethi committed
			err = fmt.Errorf("could not acquire read lock: %v", ctx.Err())
Anmol Sethi's avatar
Anmol Sethi committed
		default:
			panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err()))
		}
		c.close(err)
		return ctx.Err()
	case <-c.closed:
		return c.closeErr
	case lock <- struct{}{}:
		return nil
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) releaseLock(lock chan struct{}) {
	// Allow multiple releases.
	select {
	case <-lock:
	default:
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	for {
Anmol Sethi's avatar
Anmol Sethi committed
		h, err := c.readFrameHeader(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return header{}, err
Anmol Sethi's avatar
Anmol Sethi committed
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if h.rsv1 || h.rsv2 || h.rsv3 {
			c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
			return header{}, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		}

		if h.opcode.controlOp() {
Anmol Sethi's avatar
Anmol Sethi committed
			err = c.handleControl(ctx, h)
			if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
				return header{}, fmt.Errorf("failed to handle control frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
			}
Anmol Sethi's avatar
Anmol Sethi committed
		switch h.opcode {
		case opBinary, opText, opContinuation:
			return h, nil
Anmol Sethi's avatar
Anmol Sethi committed
		default:
			c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode))
			return header{}, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	err := c.acquireLock(context.Background(), c.readFrameLock)
	if err != nil {
		return header{}, err
	}
	defer c.releaseLock(c.readFrameLock)
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		return header{}, c.closeErr
	case c.setReadTimeout <- ctx:
	}

	h, err := readHeader(c.readHeaderBuf, c.br)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-c.closed:
			return header{}, c.closeErr
		case <-ctx.Done():
			err = ctx.Err()
		default:
		}
Anmol Sethi's avatar
Anmol Sethi committed
		err := fmt.Errorf("failed to read header: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
		c.releaseLock(c.readFrameLock)
		c.close(err)
		return header{}, err
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		return header{}, c.closeErr
	case c.setReadTimeout <- context.Background():
	}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) handleControl(ctx context.Context, h header) error {
Anmol Sethi's avatar
Anmol Sethi committed
	if h.payloadLength > maxControlFramePayload {
		c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength))
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	if !h.fin {
		c.Close(StatusProtocolError, "received fragmented control frame")
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
Anmol Sethi's avatar
Anmol Sethi committed
	defer cancel()

	b := c.controlPayloadBuf[:h.payloadLength]
Anmol Sethi's avatar
Anmol Sethi committed
	_, err := c.readFramePayload(ctx, b)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	}

	if h.masked {
		fastXOR(h.maskKey, 0, b)
	}

	switch h.opcode {
	case opPing:
Anmol Sethi's avatar
Anmol Sethi committed
		return c.writePong(b)
Anmol Sethi's avatar
Anmol Sethi committed
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			close(pong)
		}
Anmol Sethi's avatar
Anmol Sethi committed
		return nil
Anmol Sethi's avatar
Anmol Sethi committed
	case opClose:
		ce, err := parseClosePayload(b)
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			err = fmt.Errorf("received invalid close payload: %w", err)
			c.Close(StatusProtocolError, err.Error())
			return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		}
		// This ensures the closeErr of the Conn is always the received CloseError
		// in case the echo close frame write fails.
		// See https://github.com/nhooyr/websocket/issues/109
Anmol Sethi's avatar
Anmol Sethi committed
		c.setCloseErr(fmt.Errorf("received close frame: %w", ce))
		c.writeClose(b, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	default:
		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Reader waits until there is a WebSocket data message to read
// from the connection.
// It returns the type of the message and a reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
// All returned errors will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
//
Anmol Sethi's avatar
Anmol Sethi committed
// You must read from the connection for control frames to be handled.
Anmol Sethi's avatar
Anmol Sethi committed
// Thus if you expect messages to take a long time to be responded to,
// you should handle such messages async to reading from the connection
// to ensure control frames are promptly handled.
//
Anmol Sethi's avatar
Anmol Sethi committed
// If you do not expect any data messages from the peer, call CloseRead.
Anmol Sethi's avatar
Anmol Sethi committed
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and then the message
// Read, use time.AfterFunc to cancel the context passed in early.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	if c.isReadClosed.Load() == 1 {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, fmt.Errorf("websocket connection read closed")
Anmol Sethi's avatar
Anmol Sethi committed
	typ, r, err := c.reader(ctx)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, fmt.Errorf("failed to get reader: %w", err)
	return typ, r, nil
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
	if c.activeReader != nil && !c.readerFrameEOF {
Anmol Sethi's avatar
Anmol Sethi committed
		// The only way we know for sure the previous reader is not yet complete is
		// if there is an active frame not yet fully read.
		// Otherwise, a user may have read the last byte but not the EOF if the EOF
		// is in the next frame so we check for that below.
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, nil, fmt.Errorf("previous message not read to completion")
Anmol Sethi's avatar
Anmol Sethi committed
	h, err := c.readTillMsg(ctx)
	if err != nil {
		return 0, nil, err
	}
Anmol Sethi's avatar
Anmol Sethi committed

	if c.activeReader != nil && !c.activeReader.eof() {
Anmol Sethi's avatar
Anmol Sethi committed
		if h.opcode != opContinuation {
			c.Close(StatusProtocolError, "received new data message without finishing the previous message")
			return 0, nil, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		if !h.fin || h.payloadLength > 0 {
Anmol Sethi's avatar
Anmol Sethi committed
			return 0, nil, fmt.Errorf("previous message not read to completion")
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed

		c.activeReader = nil
Anmol Sethi's avatar
Anmol Sethi committed

		h, err = c.readTillMsg(ctx)
		if err != nil {
			return 0, nil, err
		}
Anmol Sethi's avatar
Anmol Sethi committed
	} else if h.opcode == opContinuation {
		c.Close(StatusProtocolError, "received continuation frame not after data or text frame")
		return 0, nil, c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

	c.readerMsgCtx = ctx
	c.readerMsgHeader = h
	c.readerFrameEOF = false
	c.readerMaskPos = 0
	c.readMsgLeft = c.msgReadLimit.Load()
Anmol Sethi's avatar
Anmol Sethi committed

	r := &messageReader{
Anmol Sethi's avatar
Anmol Sethi committed
	}
	c.activeReader = r
Anmol Sethi's avatar
Anmol Sethi committed
	return MessageType(h.opcode), r, nil
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
// messageReader is the io.Reader that Conn.Reader returns for one data
// message. It only points back at its Conn; the per-message read state
// lives in the Conn's reader fields.
type messageReader struct {
	c *Conn // connection the message is read from
}

func (r *messageReader) eof() bool {
	return r.c.activeReader != r
Anmol Sethi's avatar
Anmol Sethi committed
// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
	n, err := r.read(p)
	if err != nil {
		// Have to return io.EOF directly for now, we cannot wrap as errors.Is
		// isn't used widely yet.
Anmol Sethi's avatar
Anmol Sethi committed
		if errors.Is(err, io.EOF) {
Anmol Sethi's avatar
Anmol Sethi committed
			return n, io.EOF
		}
Anmol Sethi's avatar
Anmol Sethi committed
		return n, fmt.Errorf("failed to read: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return n, nil
}

func (r *messageReader) read(p []byte) (int, error) {
	if r.eof() {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, fmt.Errorf("cannot use EOFed reader")
Anmol Sethi's avatar
Anmol Sethi committed
	}

	if r.c.readMsgLeft <= 0 {
		r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit))
		return 0, r.c.closeErr
	if int64(len(p)) > r.c.readMsgLeft {
		p = p[:r.c.readMsgLeft]
	if r.c.readerFrameEOF {
		h, err := r.c.readTillMsg(r.c.readerMsgCtx)
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
			return 0, err
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed

		if h.opcode != opContinuation {
			r.c.Close(StatusProtocolError, "received new data message without finishing the previous message")
			return 0, r.c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		}
		r.c.readerMsgHeader = h
		r.c.readerFrameEOF = false
		r.c.readerMaskPos = 0
Anmol Sethi's avatar
Anmol Sethi committed
	}

	h := r.c.readerMsgHeader
	if int64(len(p)) > h.payloadLength {
		p = p[:h.payloadLength]
Anmol Sethi's avatar
Anmol Sethi committed
	}

	n, err := r.c.readFramePayload(r.c.readerMsgCtx, p)
Anmol Sethi's avatar
Anmol Sethi committed

	h.payloadLength -= int64(n)
	r.c.readMsgLeft -= int64(n)
		r.c.readerMaskPos = fastXOR(h.maskKey, r.c.readerMaskPos, p)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	r.c.readerMsgHeader = h
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return n, err
Anmol Sethi's avatar
Anmol Sethi committed
	}
	if h.payloadLength == 0 {
		r.c.readerFrameEOF = true
Anmol Sethi's avatar
Anmol Sethi committed

			r.c.activeReader = nil
Anmol Sethi's avatar
Anmol Sethi committed
			return n, io.EOF
		}
Anmol Sethi's avatar
Anmol Sethi committed
	return n, nil
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	err := c.acquireLock(ctx, c.readFrameLock)
	if err != nil {
		return 0, err
	}
	defer c.releaseLock(c.readFrameLock)

Anmol Sethi's avatar
Anmol Sethi committed
	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return 0, c.closeErr
	case c.setReadTimeout <- ctx:
	}

	n, err := io.ReadFull(c.br, p)
	if err != nil {
		select {
		case <-c.closed:
			return n, c.closeErr
		case <-ctx.Done():
			err = ctx.Err()
Anmol Sethi's avatar
Anmol Sethi committed
		err = fmt.Errorf("failed to read frame payload: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
		c.releaseLock(c.readFrameLock)
		c.close(err)
Anmol Sethi's avatar
Anmol Sethi committed
		return n, err
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
		return n, c.closeErr
	case c.setReadTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

	return n, err
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Read is a convenience method to read a single message from the connection.
//
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
// The docs on Reader apply to this method as well.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
Anmol Sethi's avatar
Anmol Sethi committed

	b, err := ioutil.ReadAll(r)
	return typ, b, err
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
Anmol Sethi's avatar
Anmol Sethi committed
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	wc, err := c.writer(ctx, typ)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return nil, fmt.Errorf("failed to get writer: %w", err)
	}
	return wc, nil
}

func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	err := c.acquireLock(ctx, c.writeMsgLock)
	c.writeMsgCtx = ctx
	c.writeMsgOpcode = opcode(typ)
	w := &messageWriter{
	c.activeWriter.Store(w)
	return w, nil
Anmol Sethi's avatar
Anmol Sethi committed
}

// Write is a convenience method to write a message to the connection.
//
Anmol Sethi's avatar
Anmol Sethi committed
// See the Writer method if you want to stream a message.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("failed to write msg: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return nil
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	err := c.acquireLock(ctx, c.writeMsgLock)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
	}
	defer c.releaseLock(c.writeMsgLock)

	n, err := c.writeFrame(ctx, true, opcode(typ), p)
	return n, err
Anmol Sethi's avatar
Anmol Sethi committed
}

// messageWriter is the io.WriteCloser handed out by Conn.Writer for one data
// message. The message's context and opcode are kept on the Conn itself.
type messageWriter struct {
	c *Conn // connection the message is written to
}

func (w *messageWriter) closed() bool {
	return w != w.c.activeWriter.Load()
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

// Write writes the given bytes to the WebSocket connection.
func (w *messageWriter) Write(p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	n, err := w.write(p)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return n, fmt.Errorf("failed to write: %w", err)
func (w *messageWriter) write(p []byte) (int, error) {
	if w.closed() {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, fmt.Errorf("cannot use closed writer")
	n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p)
Anmol Sethi's avatar
Anmol Sethi committed
		return n, fmt.Errorf("failed to write data frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	w.c.writeMsgOpcode = opContinuation
Anmol Sethi's avatar
Anmol Sethi committed
}

// Close flushes the frame to the connection.
// This must be called for every messageWriter.
func (w *messageWriter) Close() error {
	err := w.close()
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("failed to close writer: %w", err)
func (w *messageWriter) close() error {
	if w.closed() {
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("cannot use closed writer")
	w.c.activeWriter.Store((*messageWriter)(nil))
	_, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil)
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("failed to write fin frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	w.c.releaseLock(w.c.writeMsgLock)
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	_, err := c.writeFrame(ctx, true, opcode, p)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("failed to write control frame: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	return nil
Anmol Sethi's avatar
Anmol Sethi committed
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	err := c.acquireLock(ctx, c.writeFrameLock)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
	defer c.releaseLock(c.writeFrameLock)
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
	case c.setWriteTimeout <- ctx:
	}
	c.writeHeader.fin = fin
	c.writeHeader.opcode = opcode
	c.writeHeader.masked = c.client
	c.writeHeader.payloadLength = int64(len(p))

	if c.client {
		_, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:])
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			return 0, fmt.Errorf("failed to generate masking key: %w", err)
		}
	}

	n, err := c.realWriteFrame(ctx, *c.writeHeader, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return n, err
	}
Anmol Sethi's avatar
Anmol Sethi committed
	// We already finished writing, no need to potentially brick the connection if
	// the context expires.
	select {
	case <-c.closed:
		return n, c.closeErr
	case c.setWriteTimeout <- context.Background():
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	return n, nil
}

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	defer func() {
		if err != nil {
			select {
			case <-c.closed:
				err = c.closeErr
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}

Anmol Sethi's avatar
Anmol Sethi committed
			err = fmt.Errorf("failed to write %v frame: %w", h.opcode, err)
Anmol Sethi's avatar
Anmol Sethi committed
			// We need to release the lock first before closing the connection to ensure
			// the lock can be acquired inside close to ensure no one can access c.bw.
			c.releaseLock(c.writeFrameLock)
			c.close(err)
		}
	}()

Anmol Sethi's avatar
Anmol Sethi committed
	headerBytes := writeHeader(c.writeHeaderBuf, h)
	_, err = c.bw.Write(headerBytes)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, err
	}

	if c.client {
		var keypos int
		for len(p) > 0 {
			if c.bw.Available() == 0 {
				err = c.bw.Flush()
				if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
					return n, err
				}
			}

			// Start of next write in the buffer.
			i := c.bw.Buffered()

			p2 := p
			if len(p) > c.bw.Available() {
				p2 = p[:c.bw.Available()]
			}

			n2, err := c.bw.Write(p2)
			if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
				return n, err
			}

			keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])

			p = p[n2:]
			n += n2
		}
	} else {
		n, err = c.bw.Write(p)
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			return n, err
Anmol Sethi's avatar
Anmol Sethi committed
	if h.fin {
Anmol Sethi's avatar
Anmol Sethi committed
		err = c.bw.Flush()
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			return n, err
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed
// writePong replies to a ping by writing a pong frame that echoes p.
// The write is allowed at most five seconds.
func (c *Conn) writePong(p []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	return c.writeControl(ctx, opPong, p)
}
Anmol Sethi's avatar
Anmol Sethi committed
// Close closes the WebSocket connection with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5 seconds.
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
Anmol Sethi's avatar
Anmol Sethi committed
// This does not perform a WebSocket close handshake.
// See https://github.com/nhooyr/websocket/issues/103 for details on why.
//
Anmol Sethi's avatar
Anmol Sethi committed
// The maximum length of reason must be 125 bytes otherwise an internal
// error will be sent to the peer. For this reason, you should avoid
// sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection.
func (c *Conn) Close(code StatusCode, reason string) error {
	err := c.exportedClose(code, reason)
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("failed to close websocket connection: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	return nil
}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) exportedClose(code StatusCode, reason string) error {
	ce := CloseError{
		Code:   code,
		Reason: reason,
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	// This function also will not wait for a close frame from the peer like the RFC
	// wants because that makes no sense and I don't think anyone actually follows that.
	// Definitely worth seeing what popular browsers do later.
	p, err := ce.bytes()
	if err != nil {
		log.Printf("websocket: failed to marshal close frame: %+v", err)
Anmol Sethi's avatar
Anmol Sethi committed
		ce = CloseError{
			Code: StatusInternalError,
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed
		p, _ = ce.bytes()
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	// CloseErrors sent are made opaque to prevent applications from thinking
	// they received a given status.
Anmol Sethi's avatar
Anmol Sethi committed
	sentErr := fmt.Errorf("sent close frame: %v", ce)
	err = c.writeClose(p, sentErr)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	if !errors.Is(c.closeErr, sentErr) {
Anmol Sethi's avatar
Anmol Sethi committed
		return c.closeErr
	}

	return nil
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeClose(p []byte, cerr error) error {
Anmol Sethi's avatar
Anmol Sethi committed
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
Anmol Sethi's avatar
Anmol Sethi committed

	// If this fails, the connection had to have died.
Anmol Sethi's avatar
Anmol Sethi committed
	err := c.writeControl(ctx, opClose, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	}

Anmol Sethi's avatar
Anmol Sethi committed
	c.close(cerr)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	return nil
Anmol Sethi's avatar
Anmol Sethi committed
}

// init seeds math/rand's global source, which Ping draws its payloads from.
// NOTE(review): rand.Seed is deprecated since Go 1.20 (the global source is
// seeded automatically there), and this reseeds the process-wide source
// shared with every other package.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}

// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
Anmol Sethi's avatar
Anmol Sethi committed
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
Anmol Sethi's avatar
Anmol Sethi committed
//
// TCP Keepalives should suffice for most use cases.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Ping(ctx context.Context) error {
	id := rand.Uint64()
	p := strconv.FormatUint(id, 10)

	err := c.ping(ctx, p)
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		return fmt.Errorf("failed to ping: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return nil
}

func (c *Conn) ping(ctx context.Context, p string) error {
Anmol Sethi's avatar
Anmol Sethi committed
	pong := make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
	c.activePingsMu.Lock()
	c.activePings[p] = pong
	c.activePingsMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed

	defer func() {
Anmol Sethi's avatar
Anmol Sethi committed
		c.activePingsMu.Lock()
		delete(c.activePings, p)
		c.activePingsMu.Unlock()
Anmol Sethi's avatar
Anmol Sethi committed
	}()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	err := c.writeControl(ctx, opPing, []byte(p))
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		return err
	}

	select {
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	case <-ctx.Done():
Anmol Sethi's avatar
Anmol Sethi committed
		err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(err)
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	case <-pong:
		return nil
	}
}

// writerFunc adapts an ordinary function to the io.Writer interface.
type writerFunc func(p []byte) (int, error)

// Write calls f with p and returns its results unchanged.
func (f writerFunc) Write(p []byte) (int, error) {
	n, err := f(p)
	return n, err
}

// extractBufioWriterBuf captures the []byte backing c.bw, stores it in
// c.writeBuf, and then points c.bw at w.
//
// It temporarily resets c.bw onto a writer that records the slice handed to
// it on Flush. This relies on bufio.Writer.Flush passing a prefix of its
// internal buffer, so re-slicing to capacity yields the whole buffer.
func (c *Conn) extractBufioWriterBuf(w io.Writer) {
	capture := writerFunc(func(buf []byte) (int, error) {
		c.writeBuf = buf[:cap(buf)]
		return len(buf), nil
	})
	c.bw.Reset(capture)

	// Buffer one byte and flush it so capture sees the buffer. Neither call
	// can fail because capture never returns an error.
	_ = c.bw.WriteByte(0)
	_ = c.bw.Flush()

	c.bw.Reset(w)
}