good morning!!!!

Skip to content
Snippets Groups Projects
websocket.go 11.7 KiB
Newer Older
package websocket
Anmol Sethi's avatar
Anmol Sethi committed

import (
Anmol Sethi's avatar
Anmol Sethi committed
	"bufio"
Anmol Sethi's avatar
Anmol Sethi committed
	"context"
Anmol Sethi's avatar
Anmol Sethi committed
	"fmt"
	"io"
Anmol Sethi's avatar
Anmol Sethi committed
	"runtime"
Anmol Sethi's avatar
Anmol Sethi committed
	"sync"
	"time"

	"golang.org/x/xerrors"
Anmol Sethi's avatar
Anmol Sethi committed
)

type control struct {
	opcode  opcode
	payload []byte
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Conn represents a WebSocket connection.
Anmol Sethi's avatar
Anmol Sethi committed
// Pings will always be automatically responded to with pongs, you do not
// have to do anything special.
type Conn struct {
	subprotocol string
	br          *bufio.Reader
Anmol Sethi's avatar
Anmol Sethi committed
	// TODO Cannot use bufio writer because for compression we need to know how much is buffered and compress it if large.
	bw     *bufio.Writer
	closer io.Closer
	client bool
Anmol Sethi's avatar
Anmol Sethi committed

	closeOnce sync.Once
	closeErr  error
	closed    chan struct{}

	// Writers should send on write to begin sending
	// a message and then follow that up with some data
	// on writeBytes.
	write      chan DataType
	control    chan control
Anmol Sethi's avatar
Anmol Sethi committed
	writeBytes chan []byte
	writeDone  chan struct{}
Anmol Sethi's avatar
Anmol Sethi committed

	// Readers should receive on read to begin reading a message.
	// Then send a byte slice to readBytes to read into it.
Anmol Sethi's avatar
Anmol Sethi committed
	// The n of bytes read will be sent on readDone once the read into a slice is complete.
	// readDone will receive 0 when EOF is reached.
Anmol Sethi's avatar
Anmol Sethi committed
	read      chan opcode
	readBytes chan []byte
Anmol Sethi's avatar
Anmol Sethi committed
	readDone  chan int
Anmol Sethi's avatar
Anmol Sethi committed
}

func (c *Conn) getCloseErr() error {
Anmol Sethi's avatar
Anmol Sethi committed
	if c.closeErr != nil {
		return c.closeErr
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
	return nil
Anmol Sethi's avatar
Anmol Sethi committed
}

func (c *Conn) close(err error) {
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		err = xerrors.Errorf("websocket: connection broken: %w", err)
Anmol Sethi's avatar
Anmol Sethi committed
	}

	c.closeOnce.Do(func() {
Anmol Sethi's avatar
Anmol Sethi committed
		runtime.SetFinalizer(c, nil)

Anmol Sethi's avatar
Anmol Sethi committed
		c.closeErr = err

		cerr := c.closer.Close()
		if c.closeErr == nil {
			c.closeErr = cerr
		}

		close(c.closed)
	})
}
Anmol Sethi's avatar
Anmol Sethi committed

// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
Anmol Sethi's avatar
Anmol Sethi committed
	return c.subprotocol
}

func (c *Conn) init() {
	c.closed = make(chan struct{})
	c.write = make(chan DataType)
	c.control = make(chan control)
	c.writeDone = make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
	c.read = make(chan opcode)
Anmol Sethi's avatar
Anmol Sethi committed
	c.readDone = make(chan int)
Anmol Sethi's avatar
Anmol Sethi committed
	c.readBytes = make(chan []byte)
Anmol Sethi's avatar
Anmol Sethi committed
	runtime.SetFinalizer(c, func(c *Conn) {
		c.Close(StatusInternalError, "websocket: connection ended up being garbage collected")
	})

Anmol Sethi's avatar
Anmol Sethi committed
	go c.writeLoop()
	go c.readLoop()
Anmol Sethi's avatar
Anmol Sethi committed
}

func (c *Conn) writeFrame(h header, p []byte) {
	b2 := marshalHeader(h)
	_, err := c.bw.Write(b2)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.Errorf("failed to write to connection: %w", err))
		return
	}

	_, err = c.bw.Write(p)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.Errorf("failed to write to connection: %w", err))
		return
	}

	if h.opcode.controlOp() {
		err := c.bw.Flush()
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(xerrors.Errorf("failed to write to connection: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeLoop() {
messageLoop:
	for {
		c.writeBytes = make(chan []byte)

		var dataType DataType
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-c.closed:
			return
		case dataType = <-c.write:
		case control := <-c.control:
			h := header{
				fin:           true,
				opcode:        control.opcode,
				payloadLength: int64(len(control.payload)),
				masked:        c.client,
			}
			c.writeFrame(h, control.payload)
			select {
			case <-c.closed:
				return
			case c.writeDone <- struct{}{}:
			}
Anmol Sethi's avatar
Anmol Sethi committed
		}

		var firstSent bool
		for {
			select {
			case <-c.closed:
				return
			case control := <-c.control:
				h := header{
					fin:           true,
					opcode:        control.opcode,
					payloadLength: int64(len(control.payload)),
					masked:        c.client,
				}
				c.writeFrame(h, control.payload)
				c.writeDone <- struct{}{}
				continue
Anmol Sethi's avatar
Anmol Sethi committed
			case b, ok := <-c.writeBytes:
				h := header{
					fin:           !ok,
					opcode:        opcode(dataType),
					payloadLength: int64(len(b)),
					masked:        c.client,
				}
				if firstSent {
					h.opcode = opContinuation
				}
				firstSent = true
				c.writeFrame(h, b)
				if !ok {
					err := c.bw.Flush()
Anmol Sethi's avatar
Anmol Sethi committed
					if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
						c.close(xerrors.Errorf("failed to write to connection: %w", err))
				select {
				case <-c.closed:
					return
				case c.writeDone <- struct{}{}:
					if ok {
						continue
					} else {
						continue messageLoop
Anmol Sethi's avatar
Anmol Sethi committed
					}
				}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) handleControl(h header) {
	if h.payloadLength > maxControlFramePayload {
		c.Close(StatusProtocolError, "control frame too large")
		return
	}
Anmol Sethi's avatar
Anmol Sethi committed

	if !h.fin {
		c.Close(StatusProtocolError, "control frame cannot be fragmented")
		return
	}

Anmol Sethi's avatar
Anmol Sethi committed
	b := make([]byte, h.payloadLength)
	_, err := io.ReadFull(c.br, b)
	if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
		c.close(xerrors.Errorf("failed to read control frame payload: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
		return
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	if h.masked {
		mask(h.maskKey, 0, b)
	}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	switch h.opcode {
	case opPing:
		c.writePong(b)
	case opPong:
	case opClose:
		if len(b) > 0 {
			code, reason, err := parseClosePayload(b)
			if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
				c.close(xerrors.Errorf("read invalid close payload: %w", err))
				return
			}
			c.Close(code, reason)
		} else {
Anmol Sethi's avatar
Anmol Sethi committed
			c.writeClose(nil, CloseError{
				Code: StatusNoStatusRcvd,
			})
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed
	default:
		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
Anmol Sethi's avatar
Anmol Sethi committed
	}
}

func (c *Conn) readLoop() {
Anmol Sethi's avatar
Anmol Sethi committed
	var indata bool
Anmol Sethi's avatar
Anmol Sethi committed
	for {
		h, err := readHeader(c.br)
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(xerrors.Errorf("failed to read header: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
			return
		}

Anmol Sethi's avatar
Anmol Sethi committed
		if h.rsv1 || h.rsv2 || h.rsv3 {
			c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
Anmol Sethi's avatar
Anmol Sethi committed
			return
		}

		if h.opcode.controlOp() {
Anmol Sethi's avatar
Anmol Sethi committed
			c.handleControl(h)
			continue
		}

Anmol Sethi's avatar
Anmol Sethi committed
		switch h.opcode {
		case opBinary, opText:
			if !indata {
				select {
				case <-c.closed:
					return
				case c.read <- h.opcode:
				}
				indata = true
			} else {
				c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished")
				return
			}
		case opContinuation:
			if !indata {
				c.Close(StatusProtocolError, "continuation frame not after data or text frame")
				return
			}
Anmol Sethi's avatar
Anmol Sethi committed
		default:
			// TODO send back protocol violation message or figure out what RFC wants.
Anmol Sethi's avatar
Anmol Sethi committed
			c.close(xerrors.Errorf("unexpected opcode in header: %#v", h))
			return
		}

		maskPos := 0
Anmol Sethi's avatar
Anmol Sethi committed
		left := h.payloadLength
		firstRead := false
		for left > 0 || !firstRead {
Anmol Sethi's avatar
Anmol Sethi committed
			select {
			case <-c.closed:
				return
			case b := <-c.readBytes:
				if int64(len(b)) > left {
					b = b[:left]
				}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
				_, err = io.ReadFull(c.br, b)
				if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
					c.close(xerrors.Errorf("failed to read from connection: %w", err))
Anmol Sethi's avatar
Anmol Sethi committed
					return
				}
				left -= int64(len(b))
Anmol Sethi's avatar
Anmol Sethi committed
				if h.masked {
					maskPos = mask(h.maskKey, maskPos, b)
Anmol Sethi's avatar
Anmol Sethi committed
				}

Anmol Sethi's avatar
Anmol Sethi committed
				select {
				case <-c.closed:
					return
				case c.readDone <- len(b):
					firstRead = true
Anmol Sethi's avatar
Anmol Sethi committed
		}

		if h.fin {
			indata = false
			select {
			case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
				return
Anmol Sethi's avatar
Anmol Sethi committed
			case c.readDone <- 0:
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writePong(p []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.writeControl(ctx, opPong, p)
Anmol Sethi's avatar
Anmol Sethi committed
	return err
Anmol Sethi's avatar
Anmol Sethi committed
}

// Close closes the WebSocket connection with the given status code and reason.
// It will write a WebSocket close frame with a timeout of 5 seconds.
func (c *Conn) Close(code StatusCode, reason string) error {
Anmol Sethi's avatar
Anmol Sethi committed
	// This function also will not wait for a close frame from the peer like the RFC
	// wants because that makes no sense and I don't think anyone actually follows that.
	// Definitely worth seeing what popular browsers do later.
Anmol Sethi's avatar
Anmol Sethi committed
	p, err := closePayload(code, reason)
	if err != nil {
		p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code))
	}

Anmol Sethi's avatar
Anmol Sethi committed
	err2 := c.writeClose(p, CloseError{
		Code:   code,
		Reason: reason,
	})
	if err != nil {
		return err
	}
Anmol Sethi's avatar
Anmol Sethi committed
	return err2
}
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) writeClose(p []byte, cerr CloseError) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	err := c.writeControl(ctx, opClose, p)

	c.close(cerr)

	if err != nil {
		return err
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed

	if cerr != c.closeErr {
		return c.closeErr
	}

	return nil
Anmol Sethi's avatar
Anmol Sethi committed

func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		return c.getCloseErr()
	case c.control <- control{
		opcode:  opcode,
		payload: p,
	}:
Anmol Sethi's avatar
Anmol Sethi committed
	case <-ctx.Done():
		c.close(xerrors.New("force closed: close frame write timed out"))
Anmol Sethi's avatar
Anmol Sethi committed
		return c.getCloseErr()
Anmol Sethi's avatar
Anmol Sethi committed
	}

	select {
	case <-c.closed:
Anmol Sethi's avatar
Anmol Sethi committed
		return c.getCloseErr()
	case <-c.writeDone:
		return nil
	case <-ctx.Done():
		return ctx.Err()
Anmol Sethi's avatar
Anmol Sethi committed
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// Write returns a writer bounded by the context that will write
Anmol Sethi's avatar
Anmol Sethi committed
// a WebSocket data frame of type dataType to the connection.
// Ensure you close the messageWriter once you have written to entire message.
// Concurrent calls to messageWriter are ok.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser {
	return &messageWriter{
Anmol Sethi's avatar
Anmol Sethi committed
		c:        c,
Anmol Sethi's avatar
Anmol Sethi committed
		datatype: dataType,
	}
}

// messageWriter enables writing to a WebSocket connection.
// Ensure you close the messageWriter once you have written to entire message.
type messageWriter struct {
	datatype     DataType
Anmol Sethi's avatar
Anmol Sethi committed
	ctx          context.Context
	c            *Conn
	acquiredLock bool
	sentFirst    bool

	done chan struct{}
}
Anmol Sethi's avatar
Anmol Sethi committed

// Write writes the given bytes to the WebSocket connection.
// The frame will automatically be fragmented as appropriate
// with the buffers obtained from http.Hijacker.
// Please ensure you call Close once you have written the full message.
func (w *messageWriter) Write(p []byte) (int, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	if !w.acquiredLock {
		select {
		case <-w.c.closed:
			return 0, w.c.getCloseErr()
		case w.c.write <- w.datatype:
Anmol Sethi's avatar
Anmol Sethi committed
			w.acquiredLock = true
		case <-w.ctx.Done():
			return 0, w.ctx.Err()
		}
	}

	select {
	case <-w.c.closed:
		return 0, w.c.getCloseErr()
	case w.c.writeBytes <- p:
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-w.c.closed:
			return 0, w.c.getCloseErr()
		case <-w.c.writeDone:
Anmol Sethi's avatar
Anmol Sethi committed
			return len(p), nil
		case <-w.ctx.Done():
			return 0, w.ctx.Err()
		}
Anmol Sethi's avatar
Anmol Sethi committed
	case <-w.ctx.Done():
		return 0, w.ctx.Err()
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

// Close flushes the frame to the connection.
// This must be called for every messageWriter.
func (w *messageWriter) Close() error {
Anmol Sethi's avatar
Anmol Sethi committed
	if !w.acquiredLock {
Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case <-w.c.closed:
			return w.c.getCloseErr()
		case w.c.write <- w.datatype:
Anmol Sethi's avatar
Anmol Sethi committed
			w.acquiredLock = true
		case <-w.ctx.Done():
			return w.ctx.Err()
		}
Anmol Sethi's avatar
Anmol Sethi committed
	}
	close(w.c.writeBytes)
	select {
	case <-w.c.closed:
		return w.c.getCloseErr()
	case <-w.ctx.Done():
		return w.ctx.Err()
	case <-w.c.writeDone:
		return nil
	}
Anmol Sethi's avatar
Anmol Sethi committed
}

Anmol Sethi's avatar
Anmol Sethi committed
// ReadMessage will wait until there is a WebSocket data frame to read from the connection.
// It returns the type of the data, a reader to read it and also an error.
// Please use SetContext on the reader to bound the read operation.
// Your application must keep reading messages for the Conn to automatically respond to ping
// and close frames.
Anmol Sethi's avatar
Anmol Sethi committed
func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-c.closed:
		return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr())
	case opcode := <-c.read:
		return DataType(opcode), &messageReader{
			ctx: ctx,
Anmol Sethi's avatar
Anmol Sethi committed
			c:   c,
		}, nil
	case <-ctx.Done():
		return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err())
	}
}

// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
	ctx context.Context
	c   *Conn
Anmol Sethi's avatar
Anmol Sethi committed
}
Anmol Sethi's avatar
Anmol Sethi committed

// SetContext bounds the read operation to the ctx.
// By default, the context is the one passed to conn.ReadMessage.
// You still almost always want a separate context for reading the message though.
func (r *messageReader) SetContext(ctx context.Context) {
Anmol Sethi's avatar
Anmol Sethi committed
	r.ctx = ctx
Anmol Sethi's avatar
Anmol Sethi committed
}

// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (n int, err error) {
Anmol Sethi's avatar
Anmol Sethi committed
	select {
	case <-r.c.closed:
		return 0, r.c.getCloseErr()
	case <-r.c.readDone:
		return 0, io.EOF
	case r.c.readBytes <- p:
		select {
		case <-r.c.closed:
			return 0, r.c.getCloseErr()
Anmol Sethi's avatar
Anmol Sethi committed
		case n := <-r.c.readDone:
			return n, nil
Anmol Sethi's avatar
Anmol Sethi committed
		case <-r.ctx.Done():
			return 0, r.ctx.Err()
		}
	case <-r.ctx.Done():
		return 0, r.ctx.Err()
	}
Anmol Sethi's avatar
Anmol Sethi committed
}