good morning!!!!

Skip to content
Snippets Groups Projects
websocket.go 11.9 KiB
Newer Older
package websocket
Anmol Sethi's avatar
Anmol Sethi committed

import (
Anmol Sethi's avatar
Anmol Sethi committed
	"bufio"
Anmol Sethi's avatar
Anmol Sethi committed
	"context"
Anmol Sethi's avatar
Anmol Sethi committed
	"fmt"
	"io"
Anmol Sethi's avatar
Anmol Sethi committed
	"runtime"
Anmol Sethi's avatar
Anmol Sethi committed
	"sync"
	"time"

	"golang.org/x/xerrors"
Anmol Sethi's avatar
Anmol Sethi committed
)

type control struct {
	opcode  opcode
	payload []byte
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Conn represents a WebSocket connection.
Anmol Sethi's avatar
Anmol Sethi committed
// Pings will always be automatically responded to with pongs, you do not
// have to do anything special.
type Conn struct {
	subprotocol string
	br          *bufio.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	bw          *bufio.Writer
	closer      io.Closer
	client      bool
Anmol Sethi's avatar
Anmol Sethi committed

	closeOnce sync.Once
	closeErr  error
	closed    chan struct{}

	// Writers should send on write to begin sending
	// a message and then follow that up with some data
	// on writeBytes.
	write      chan MessageType
	control    chan control
Anmol Sethi's avatar
Anmol Sethi committed
	writeBytes chan []byte
	writeDone  chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed

	// Readers should receive on read to begin reading a message.
	// Then send a byte slice to readBytes to read into it.
Anmol Sethi's avatar
Anmol Sethi committed
	// The n of bytes read will be sent on readDone once the read into a slice is complete.
	// readDone will receive 0 when EOF is reached.
	read       chan opcode
	readBytes  chan []byte
	readDone   chan int
	readerDone chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed
}

func (c *Conn) close(err error) {
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		err = xerrors.Errorf("websocket: connection broken: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}

	c.closeOnce.Do(func() {
Anmol Sethi's avatar
Anmol Sethi committed
		runtime.SetFinalizer(c, nil)

Anmol Sethi's avatar
Anmol Sethi committed
		c.closeErr = err

		cerr := c.closer.Close()
		if c.closeErr == nil {
			c.closeErr = cerr
		}

		close(c.closed)
	})
}
Anmol Sethi's avatar
Anmol Sethi committed

// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Subprotocol() string {
Anmol Sethi's avatar
Anmol Sethi committed
	return c.subprotocol
}

func (c *Conn) init() {
	c.closed = make(chan struct{})
	c.write = make(chan MessageType)
	c.control = make(chan control)
	c.writeDone = make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
	c.read = make(chan opcode)
Anmol Sethi's avatar
Anmol Sethi committed
	c.readDone = make(chan int)
Anmol Sethi's avatar
Anmol Sethi committed
	c.readBytes = make(chan []byte)
	c.readerDone = make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
	runtime.SetFinalizer(c, func(c *Conn) {
		c.Close(StatusInternalError, "websocket: connection ended up being garbage collected")
	})

Anmol Sethi's avatar
Anmol Sethi committed
	go c.writeLoop()
	go c.readLoop()
Anmol Sethi's avatar
Anmol Sethi committed
}

// writeFrame writes the marshaled header h and then payload p to c.bw.
// Control frames are flushed immediately; data frames stay buffered and
// writeLoop flushes them when the message ends. Any write or flush error
// closes the connection through c.close.
//
// NOTE(review): p is written verbatim even when h.masked is true, and the
// visible callers never set a mask key — confirm that client-side masking
// is handled elsewhere (e.g. in marshalHeader).
func (c *Conn) writeFrame(h header, p []byte) {
	b2 := marshalHeader(h)
	_, err := c.bw.Write(b2)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.Errorf("failed to write to connection: %w", err))
		return
	}

	_, err = c.bw.Write(p)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.Errorf("failed to write to connection: %w", err))
		return
	}

	// Control frames are flushed right away, presumably so pongs and close
	// frames are not held back behind a partially written data message.
	if h.opcode.controlOp() {
		err := c.bw.Flush()
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(xerrors.Errorf("failed to write to connection: %w", err))
			// NOTE(review): the remainder of this function (at least the
			// closing braces of both ifs and of the function) is missing
			// from this copy of the file.
Anmol Sethi's avatar
Anmol Sethi committed
// writeLoop is the single goroutine that writes frames to the connection.
// Each outer iteration waits for either a control frame or the start of a
// data message (its type arriving on c.write). For a data message, it then
// writes one frame per slice received on c.writeBytes. Control frames
// received in the meantime are written in between. Every frame written is
// acknowledged on c.writeDone. The loop exits once c.closed is closed.
func (c *Conn) writeLoop() {
messageLoop:
	for {
		// A fresh channel per message: messageWriter.Close closes it to
		// mark the end of the message.
		c.writeBytes = make(chan []byte)
		var dataType MessageType
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-c.closed:
			return
		case dataType = <-c.write:
		case control := <-c.control:
			h := header{
				fin:           true,
				opcode:        control.opcode,
				payloadLength: int64(len(control.payload)),
				masked:        c.client,
			}
			c.writeFrame(h, control.payload)
			select {
			case <-c.closed:
				return
			case c.writeDone <- struct{}{}:
			}
			// NOTE(review): after a control frame, execution falls out of
			// this select into the data loop below with dataType still at its
			// zero value; a `continue messageLoop` may have been lost here —
			// confirm against upstream.
Anmol Sethi's avatar
Anmol Sethi committed
		}

		// firstSent switches later frames of the message to opContinuation.
		var firstSent bool
		for {
			select {
			case <-c.closed:
				return
			case control := <-c.control:
				// Control frames may be interleaved between data frames.
				h := header{
					fin:           true,
					opcode:        control.opcode,
					payloadLength: int64(len(control.payload)),
					masked:        c.client,
				}
				c.writeFrame(h, control.payload)
				select {
				case <-c.closed:
					return
				case c.writeDone <- struct{}{}:
					continue
				}
Anmol Sethi's avatar
Anmol Sethi committed
			case b, ok := <-c.writeBytes:
				// writeBytes being closed (ok == false) marks the final
				// frame; b is then nil, so that frame has an empty payload.
				h := header{
					fin:           !ok,
					opcode:        opcode(dataType),
					payloadLength: int64(len(b)),
					masked:        c.client,
				}
				if firstSent {
					h.opcode = opContinuation
				}
				firstSent = true
				c.writeFrame(h, b)
				if !ok {
					// Data frames are only flushed at the end of the message.
					err := c.bw.Flush()
Anmol Sethi's avatar
Anmol Sethi committed
					if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
						c.close(xerrors.Errorf("failed to write to connection: %w", err))
						// NOTE(review): the closing braces of this error check
						// and of `if !ok` are missing from this copy of the file.
				select {
				case <-c.closed:
					return
				case c.writeDone <- struct{}{}:
					if ok {
						continue
					} else {
						continue messageLoop
Anmol Sethi's avatar
Anmol Sethi committed
					}
				}
				// NOTE(review): the closing braces of the inner loop, the outer
				// loop and the function are missing from this copy of the file.
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
// handleControl processes one control frame whose header h has already been
// read by readLoop.
//
// Oversized or fragmented control frames close the connection with
// StatusProtocolError. Otherwise the payload is read and unmasked. A ping is
// answered with a pong carrying the same payload, a pong is ignored, and a
// close frame is answered by closing the connection.
func (c *Conn) handleControl(h header) {
	if h.payloadLength > maxControlFramePayload {
		c.Close(StatusProtocolError, "control frame too large")
		return
	}
Anmol Sethi's avatar
Anmol Sethi committed

	// RFC 6455 forbids fragmented control frames.
	if !h.fin {
		c.Close(StatusProtocolError, "control frame cannot be fragmented")
		return
	}

Anmol Sethi's avatar
Anmol Sethi committed
	b := make([]byte, h.payloadLength)
	_, err := io.ReadFull(c.br, b)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.Errorf("failed to read control frame payload: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
		return
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	if h.masked {
		mask(h.maskKey, 0, b)
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	switch h.opcode {
	case opPing:
		// The error from writePong is discarded here.
		c.writePong(b)
	case opPong:
		// Pongs are accepted and ignored.
	case opClose:
		if len(b) > 0 {
			ce, err := parseClosePayload(b)
			if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
				c.close(xerrors.Errorf("read invalid close payload: %w", err))
				// NOTE(review): lines are missing from this copy of the file
				// here — presumably a `return`, the closing brace of the error
				// check, and a `} else {` before the writeClose call below.
			c.Close(ce.Code, ce.Reason)
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeClose(nil, CloseError{
				Code: StatusNoStatusRcvd,
			})
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed
	default:
		// readLoop only calls handleControl when h.opcode.controlOp() is
		// true, so this is presumably unreachable if controlOp reports only
		// ping/pong/close — confirm.
		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
Anmol Sethi's avatar
Anmol Sethi committed
	}
}

// readLoop is the single goroutine that reads frames from the connection.
// Control frames are dispatched to handleControl. For data frames, the opcode
// is announced on c.read at the start of a message. The payload is then read
// into the slices that readers send on c.readBytes, and each count is
// reported on c.readDone. c.readerDone is signalled after the final frame of
// a message. Protocol violations close the connection with
// StatusProtocolError.
func (c *Conn) readLoop() {
Anmol Sethi's avatar
Anmol Sethi committed
	// indata is true while a data message has started but its final frame
	// has not been seen yet.
	var indata bool
Anmol Sethi's avatar
Anmol Sethi committed
	for {
		h, err := readHeader(c.br)
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(xerrors.Errorf("failed to read header: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
			return
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if h.rsv1 || h.rsv2 || h.rsv3 {
			c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
Anmol Sethi's avatar
Anmol Sethi committed
			return
		}

		if h.opcode.controlOp() {
Anmol Sethi's avatar
Anmol Sethi committed
			c.handleControl(h)
			continue
		}

Anmol Sethi's avatar
Anmol Sethi committed
		switch h.opcode {
		case opBinary, opText:
			if !indata {
				select {
				case <-c.closed:
					return
				case c.read <- h.opcode:
				}
				indata = true
			} else {
				c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished")
				return
			}
		case opContinuation:
			if !indata {
				c.Close(StatusProtocolError, "continuation frame not after data or text frame")
				return
			}
Anmol Sethi's avatar
Anmol Sethi committed
		default:
			c.Close(StatusProtocolError, fmt.Sprintf("unknown opcode %v", h.opcode))
Anmol Sethi's avatar
Anmol Sethi committed
			return
		}

		// maskPos carries the position within the mask key across partial
		// reads, so unmasking stays aligned when the payload spans buffers.
		maskPos := 0
Anmol Sethi's avatar
Anmol Sethi committed
		left := h.payloadLength
		// firstRead forces at least one iteration, so even an empty frame
		// consumes one reader buffer and reports 0 on readDone.
		firstRead := false
		for left > 0 || !firstRead {
Anmol Sethi's avatar
Anmol Sethi committed
			select {
			case <-c.closed:
				return
			case b := <-c.readBytes:
				// Never read past the end of the current frame.
				if int64(len(b)) > left {
					b = b[:left]
				}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
				_, err = io.ReadFull(c.br, b)
				if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
					c.close(xerrors.Errorf("failed to read from connection: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
					return
				}
				left -= int64(len(b))
Anmol Sethi's avatar
Anmol Sethi committed
				if h.masked {
					maskPos = mask(h.maskKey, maskPos, b)
Anmol Sethi's avatar
Anmol Sethi committed
				}

Anmol Sethi's avatar
Anmol Sethi committed
				select {
				case <-c.closed:
					return
				case c.readDone <- len(b):
					firstRead = true
					// NOTE(review): the closing braces of this select, the
					// enclosing case/select and possibly the for loop are
					// missing from this copy of the file.
Anmol Sethi's avatar
Anmol Sethi committed
		}

		if h.fin {
			indata = false
			select {
			case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
				return
			case c.readerDone <- struct{}{}:
				// NOTE(review): the rest of this function (closing braces of
				// the select, the if, the outer loop and the function) is
				// missing from this copy of the file.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writePong(p []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.writeControl(ctx, opPong, p)
Anmol Sethi's avatar
Anmol Sethi committed
	return err
Anmol Sethi's avatar
Anmol Sethi committed
}

// Close closes the WebSocket connection with the given status code and reason.
// It will write a WebSocket close frame with a timeout of 5 seconds.
func (c *Conn) Close(code StatusCode, reason string) error {
	ce := CloseError{
		Code:   code,
		Reason: reason,
	}

Anmol Sethi's avatar
Anmol Sethi committed
	// This function also will not wait for a close frame from the peer like the RFC
	// wants because that makes no sense and I don't think anyone actually follows that.
	// Definitely worth seeing what popular browsers do later.
	p, err := ce.bytes()
Anmol Sethi's avatar
Anmol Sethi committed
	if err != nil {
		ce = CloseError{
			Code: StatusInternalError,
		}
		p, _ = ce.bytes()
Anmol Sethi's avatar
Anmol Sethi committed
	}

	cerr := c.writeClose(p, ce)
	if err != nil {
		return err
	}
	return cerr
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeClose(p []byte, cerr CloseError) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.writeControl(ctx, opClose, p)

	c.close(cerr)

	if err != nil {
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

	if cerr != c.closeErr {
		return c.closeErr
	}

	return nil
Anmol Sethi's avatar
Anmol Sethi committed

func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		return c.closeErr
	case c.control <- control{
		opcode:  opcode,
		payload: p,
	}:
Anmol Sethi's avatar
Anmol Sethi committed
	case <-ctx.Done():
		c.close(xerrors.New("force closed: close frame write timed out"))
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	}

	select {
	case <-c.closed:
		return c.closeErr
	case <-c.writeDone:
		return nil
	case <-ctx.Done():
		return ctx.Err()
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Write returns a writer bounded by the context that will write
Anmol Sethi's avatar
Anmol Sethi committed
// a WebSocket data frame of type dataType to the connection.
// Ensure you close the messageWriter once you have written to entire message.
// Concurrent calls to messageWriter are ok.
func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser {
	// TODO acquire write here, move state into Conn and make messageWriter allocation free.
	return &messageWriter{
Anmol Sethi's avatar
Anmol Sethi committed
		c:        c,
Anmol Sethi's avatar
Anmol Sethi committed
		datatype: dataType,
	}
}

// messageWriter enables writing to a WebSocket connection.
// Ensure you close the messageWriter once you have written to entire message.
type messageWriter struct {
	datatype     MessageType
Anmol Sethi's avatar
Anmol Sethi committed
	ctx          context.Context
	c            *Conn
	acquiredLock bool
}
Anmol Sethi's avatar
Anmol Sethi committed

// Write writes the given bytes to the WebSocket connection.
// The frame will automatically be fragmented as appropriate
// with the buffers obtained from http.Hijacker.
// Please ensure you call Close once you have written the full message.
func (w *messageWriter) Write(p []byte) (int, error) {
	err := w.acquire()
	if err != nil {
		return 0, err
Anmol Sethi's avatar
Anmol Sethi committed
	}

	select {
	case <-w.c.closed:
		return 0, w.c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	case w.c.writeBytes <- p:
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-w.c.closed:
			return 0, w.c.closeErr
		case <-w.c.writeDone:
Anmol Sethi's avatar
Anmol Sethi committed
			return len(p), nil
		case <-w.ctx.Done():
			return 0, w.ctx.Err()
		}
Anmol Sethi's avatar
Anmol Sethi committed
	case <-w.ctx.Done():
		return 0, w.ctx.Err()
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

func (w *messageWriter) acquire() error {
Anmol Sethi's avatar
Anmol Sethi committed
	if !w.acquiredLock {
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-w.c.closed:
			return w.c.closeErr
		case w.c.write <- w.datatype:
Anmol Sethi's avatar
Anmol Sethi committed
			w.acquiredLock = true
		case <-w.ctx.Done():
			return w.ctx.Err()
		}
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return nil
}

// Close flushes the frame to the connection.
// This must be called for every messageWriter.
func (w *messageWriter) Close() error {
	err := w.acquire()
	if err != nil {
		return err
	}

Anmol Sethi's avatar
Anmol Sethi committed
	close(w.c.writeBytes)
	select {
	case <-w.c.closed:
		return w.c.closeErr
	case <-w.ctx.Done():
		return w.ctx.Err()
	case <-w.c.writeDone:
		return nil
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// ReadMessage will wait until there is a WebSocket data frame to read from the connection.
// It returns the type of the data, a reader to read it and also an error.
// Please use SetContext on the reader to bound the read operation.
// Your application must keep reading messages for the Conn to automatically respond to ping
// and close frames.
func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) {
	// TODO error if the reader is not done
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.readerDone:
		// The previous reader just hit a io.EOF, we handle it for users
		return c.Read(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
	case <-c.closed:
		return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr)
Anmol Sethi's avatar
Anmol Sethi committed
	case opcode := <-c.read:
		return MessageType(opcode), messageReader{
Anmol Sethi's avatar
Anmol Sethi committed
			c:   c,
		}, nil
	case <-ctx.Done():
		return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err())
	}
}

// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
	ctx context.Context
	c   *Conn
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

// Read reads up to len(p) bytes of the current message into p.
// It returns 0, io.EOF at the end of the message; other errors are wrapped.
func (r messageReader) Read(p []byte) (int, error) {
	n, err := r.read(p)
	switch {
	case err == nil:
		return n, nil
	case err == io.EOF:
		// The io.Reader contract has callers compare against io.EOF with ==,
		// so it must be returned unwrapped.
		return 0, io.EOF
	default:
		return n, xerrors.Errorf("failed to read: %w", err)
	}
}

func (r messageReader) read(p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-r.c.closed:
		return 0, r.c.closeErr
	case <-r.c.readerDone:
Anmol Sethi's avatar
Anmol Sethi committed
		return 0, io.EOF
	case r.c.readBytes <- p:
		// TODO this is potentially racey as if we return if the context is cancelled, or the conn is closed we don't know if the p is ok to use. we must close the connection and also ensure the readLoop is done before returning, likewise with writes.
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-r.c.closed:
			return 0, r.c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
		case n := <-r.c.readDone:
			return n, nil
Anmol Sethi's avatar
Anmol Sethi committed
		case <-r.ctx.Done():
			return 0, r.ctx.Err()
		}
	case <-r.ctx.Done():
		return 0, r.ctx.Err()
	}
Anmol Sethi's avatar
Anmol Sethi committed
}