Newer
Older
type control struct {
opcode opcode
payload []byte
// Pings will always be automatically responded to with pongs, you do not
// have to do anything special.
type Conn struct {
// subprotocol is the negotiated subprotocol reported by Subprotocol.
// An empty string means the default protocol.
subprotocol string
// br is the buffered reader that frame payloads are read from.
br *bufio.Reader
// bw is the buffered writer that frame headers and payloads are written to.
// It is flushed explicitly after control frames.
bw *bufio.Writer
// closer is closed by close when the connection is torn down.
closer io.Closer
// client is copied into the masked field of every outgoing frame header.
client bool
// closeOnce presumably guards the teardown in close so that it only
// runs once. NOTE(review): confirm against close.
closeOnce sync.Once
// closeErr records why the connection was closed. It is set in close and
// is the error reported to callers once the connection is gone.
closeErr error
// closed is closed by close. The loops and the public methods select on
// it so that they stop waiting once the connection is gone.
closed chan struct{}
// Writers send on write to begin a message and then follow that up
// with the message data on writeBytes.
// Readers receive on read to begin reading a message, then send a
// byte slice on readBytes to have it filled.
// The number of bytes read is sent on readDone once the read into
// that slice is complete.
// readDone will receive 0 when EOF is reached.
read chan opcode
readBytes chan []byte
readDone chan int
// readerDone is received from in Read; it appears to signal that the
// previous message reader reached io.EOF. NOTE(review): confirm the sender.
readerDone chan struct{}
}
func (c *Conn) close(err error) {
if err != nil {
err = xerrors.Errorf("websocket: connection broken: %w", err)
c.closeErr = err
cerr := c.closer.Close()
if c.closeErr == nil {
c.closeErr = cerr
}
close(c.closed)
})
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
return c.subprotocol
}
func (c *Conn) init() {
c.closed = make(chan struct{})
c.control = make(chan control)
c.writeDone = make(chan struct{})
runtime.SetFinalizer(c, func(c *Conn) {
c.Close(StatusInternalError, "websocket: connection ended up being garbage collected")
})
func (c *Conn) writeFrame(h header, p []byte) {
b2 := marshalHeader(h)
_, err := c.bw.Write(b2)
if err != nil {
c.close(xerrors.Errorf("failed to write to connection: %w", err))
return
}
_, err = c.bw.Write(p)
if err != nil {
c.close(xerrors.Errorf("failed to write to connection: %w", err))
return
}
if h.opcode.controlOp() {
err := c.bw.Flush()
if err != nil {
c.close(xerrors.Errorf("failed to write to connection: %w", err))
func (c *Conn) writeLoop() {
messageLoop:
for {
c.writeBytes = make(chan []byte)
case dataType = <-c.write:
case control := <-c.control:
h := header{
fin: true,
opcode: control.opcode,
payloadLength: int64(len(control.payload)),
masked: c.client,
}
c.writeFrame(h, control.payload)
select {
case <-c.closed:
return
case c.writeDone <- struct{}{}:
}
}
var firstSent bool
for {
select {
case <-c.closed:
return
case control := <-c.control:
h := header{
fin: true,
opcode: control.opcode,
payloadLength: int64(len(control.payload)),
masked: c.client,
}
c.writeFrame(h, control.payload)
select {
case <-c.closed:
return
case c.writeDone <- struct{}{}:
continue
}
h := header{
fin: !ok,
opcode: opcode(dataType),
payloadLength: int64(len(b)),
masked: c.client,
}
if firstSent {
h.opcode = opContinuation
}
firstSent = true
c.close(xerrors.Errorf("failed to write to connection: %w", err))
select {
case <-c.closed:
return
case c.writeDone <- struct{}{}:
if ok {
continue
} else {
continue messageLoop
func (c *Conn) handleControl(h header) {
if h.payloadLength > maxControlFramePayload {
c.Close(StatusProtocolError, "control frame too large")
return
}
if !h.fin {
c.Close(StatusProtocolError, "control frame cannot be fragmented")
return
}
b := make([]byte, h.payloadLength)
_, err := io.ReadFull(c.br, b)
if err != nil {
c.close(xerrors.Errorf("failed to read control frame payload: %w", err))
switch h.opcode {
case opPing:
c.writePong(b)
case opPong:
case opClose:
c.close(xerrors.Errorf("read invalid close payload: %w", err))
c.writeClose(nil, CloseError{
Code: StatusNoStatusRcvd,
})
default:
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
c.close(xerrors.Errorf("failed to read header: %w", err))
if h.rsv1 || h.rsv2 || h.rsv3 {
c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
if !indata {
select {
case <-c.closed:
return
case c.read <- h.opcode:
}
indata = true
} else {
c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished")
return
}
case opContinuation:
if !indata {
c.Close(StatusProtocolError, "continuation frame not after data or text frame")
return
}
c.Close(StatusProtocolError, fmt.Sprintf("unknown opcode %v", h.opcode))
firstRead := false
for left > 0 || !firstRead {
select {
case <-c.closed:
return
case b := <-c.readBytes:
if int64(len(b)) > left {
b = b[:left]
}
_, err = io.ReadFull(c.br, b)
if err != nil {
c.close(xerrors.Errorf("failed to read from connection: %w", err))
if h.masked {
maskPos = mask(h.maskKey, maskPos, b)
select {
case <-c.closed:
return
case c.readDone <- len(b):
}
if h.fin {
indata = false
select {
case <-c.closed:
// writePong sends a pong control frame carrying p through writeControl,
// using a context that times out after 5 seconds.
func (c *Conn) writePong(p []byte) error {
// The pong is not tied to any caller's context, so it gets its own deadline.
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.writeControl(ctx, opPong, p)
}
// Close closes the WebSocket connection with the given status code and reason.
// It will write a WebSocket close frame with a timeout of 5 seconds.
func (c *Conn) Close(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
// This function also will not wait for a close frame from the peer like the RFC
// wants because that makes no sense and I don't think anyone actually follows that.
// Definitely worth seeing what popular browsers do later.
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytes()
func (c *Conn) writeClose(p []byte, cerr CloseError) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.writeControl(ctx, opClose, p)
c.close(cerr)
if cerr != c.closeErr {
return c.closeErr
}
return nil
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
case c.control <- control{
opcode: opcode,
payload: p,
}:
case <-ctx.Done():
c.close(xerrors.New("force closed: close frame write timed out"))
case <-c.writeDone:
return nil
case <-ctx.Done():
return ctx.Err()
// Write returns a writer bounded by the context that will write
// a WebSocket data frame of type dataType to the connection.
// Ensure you close the messageWriter once you have written the entire message.
// Concurrent calls to messageWriter are ok.
func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser {
// TODO acquire write here, move state into Conn and make messageWriter allocation free.
// messageWriter enables writing to a WebSocket connection.
// Ensure you close the messageWriter once you have written the entire message.
type messageWriter struct {
// ctx bounds the writer's operations: once it is done, the writer's
// methods return ctx.Err() instead of continuing to wait.
ctx context.Context
// c is the connection the message is written to.
c *Conn
// acquiredLock is set to true once the writer has, presumably, claimed
// the connection for its message. NOTE(review): confirm against acquire.
acquiredLock bool
}
// Write writes the given bytes to the WebSocket connection.
// The frame will automatically be fragmented as appropriate
// with the buffers obtained from http.Hijacker.
// Please ensure you call Close once you have written the full message.
func (w *messageWriter) Write(p []byte) (int, error) {
err := w.acquire()
if err != nil {
return 0, err
return len(p), nil
case <-w.ctx.Done():
return 0, w.ctx.Err()
}
w.acquiredLock = true
case <-w.ctx.Done():
return w.ctx.Err()
}
return nil
}
// Close flushes the frame to the connection.
// This must be called for every messageWriter.
func (w *messageWriter) Close() error {
err := w.acquire()
if err != nil {
return err
}
case <-w.ctx.Done():
return w.ctx.Err()
case <-w.c.writeDone:
return nil
}
// Read will wait until there is a WebSocket data frame to read from the connection.
// It returns the type of the data, a reader to read it and also an error.
// Please use SetContext on the reader to bound the read operation.
// Your application must keep reading messages for the Conn to automatically respond to ping
// and close frames.
func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) {
// TODO error if the reader is not done
case <-c.readerDone:
// The previous reader just hit a io.EOF, we handle it for users
return c.Read(ctx)
return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr)
c: c,
}, nil
case <-ctx.Done():
return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err())
}
}
// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
ctx context.Context
c *Conn
// Read fills p with the next bytes of the current message by delegating to
// r.read. A nil error passes the byte count straight through, io.EOF is
// returned unwrapped with a zero count, and any other error is wrapped
// with additional context alongside the count.
func (r messageReader) Read(p []byte) (int, error) {
	n, err := r.read(p)
	switch {
	case err == nil:
		return n, nil
	case err == io.EOF:
		// Have to return io.EOF directly for now.
		return 0, io.EOF
	default:
		return n, xerrors.Errorf("failed to read: %w", err)
	}
}
func (r messageReader) read(p []byte) (int, error) {
// TODO this is potentially racy: if we return because the context is cancelled or the conn is closed, we don't know whether p is still safe to use. We must close the connection and also ensure the readLoop is done before returning; likewise with writes.
case <-r.ctx.Done():
return 0, r.ctx.Err()
}
case <-r.ctx.Done():
return 0, r.ctx.Err()
}