diff --git a/accept.go b/accept.go index 2cf1dc017a11a56473e7da7c8184cc127a637109..207ecc74ad7a81f6f96f3198c08afb83a66c4886 100644 --- a/accept.go +++ b/accept.go @@ -75,6 +75,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to WebSocket. +// // Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { @@ -132,6 +133,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, closer: netConn, } c.init() + // TODO document. + c.Context(r.Context()) return c, nil } diff --git a/ci/lint/entrypoint.sh b/ci/lint/entrypoint.sh index c539495ee30efb24850810379991a1fd4d4248fe..09c3168322beecaa4a665a3f3e0d7180bea42432 100755 --- a/ci/lint/entrypoint.sh +++ b/ci/lint/entrypoint.sh @@ -7,5 +7,5 @@ source ci/lib.sh || exit 1 shellcheck ./**/*.sh ) -go vet -composites=false ./... +go vet -composites=false -lostcancel=false ./... go run golang.org/x/lint/golint -set_exit_status ./... diff --git a/export_test.go b/export_test.go index d180e119cac2fe896220300d3c4f71afcd13f63e..465ba9eb8e4a52cab5a0220d5f5f7e7facdeac79 100644 --- a/export_test.go +++ b/export_test.go @@ -14,5 +14,5 @@ import ( // exceeds the buffer size which is 4096 right now as then an extra syscall // will be necessary to complete the message. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - return c.writeSingleFrame(ctx, opcode(typ), p) + return c.writeCompleteMessage(ctx, opcode(typ), p) } diff --git a/websocket.go b/websocket.go index c0737cad6120639c621b707dbc5ac2196c31b09f..6e35281a85518c9006b2442afc262f3133881e39 100644 --- a/websocket.go +++ b/websocket.go @@ -8,17 +8,11 @@ import ( "os" "runtime" "sync" - "sync/atomic" "time" "golang.org/x/xerrors" ) -type frame struct { - opcode opcode - payload []byte -} - // Conn represents a WebSocket connection. // All methods except Reader can be used concurrently. // Please be sure to call Close on the connection when you @@ -34,31 +28,41 @@ type Conn struct { closeErr error closed chan struct{} - // Writers should send on write to begin sending - // a message and then follow that up with some data - // on writeBytes. - // Send on control to write a control message. - // writeDone will be sent back when the message is written - // Send on writeFlush to flush the message and wait for a - // ping on writeDone. - // writeDone will be closed if the data message write errors. - write chan MessageType - control chan frame - fastWrite chan frame - writeBytes chan []byte - writeDone chan struct{} - writeFlush chan struct{} - - // Readers should receive on read to begin reading a message. - // Then send a byte slice to readBytes to read into it. - // The n of bytes read will be sent on readDone once the read into a slice is complete. - // readDone will be closed if the read fails. - // activeReader will be set to 0 on io.EOF. - activeReader int64 - inMsg bool - read chan opcode - readBytes chan []byte - readDone chan int + writeDataLock chan struct{} + writeFrameLock chan struct{} + + readData chan header + readDone chan struct{} + + setReadTimeout chan context.Context + setWriteTimeout chan context.Context + setConnContext chan context.Context + getConnContext chan context.Context +} + +// Context returns a context derived from parent that will be cancelled +// when the connection is closed. +// If the parent context is cancelled, the connection will be closed. +// +// This is an experimental API meaning it may be remove in the future. +// Please let me know how you feel about it. +func (c *Conn) Context(parent context.Context) context.Context { + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case c.setConnContext <- parent: + } + + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case ctx := <-c.getConnContext: + return ctx + } } func (c *Conn) close(err error) { @@ -85,124 +89,110 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) - c.write = make(chan MessageType) - c.control = make(chan frame) - c.fastWrite = make(chan frame) - c.writeBytes = make(chan []byte) - c.writeDone = make(chan struct{}) - c.writeFlush = make(chan struct{}) + c.writeDataLock = make(chan struct{}, 1) + c.writeFrameLock = make(chan struct{}, 1) + + c.readData = make(chan header) + c.readDone = make(chan struct{}) - c.read = make(chan opcode) - c.readBytes = make(chan []byte) - c.readDone = make(chan int) + c.setReadTimeout = make(chan context.Context) + c.setWriteTimeout = make(chan context.Context) + c.setConnContext = make(chan context.Context) + c.getConnContext = make(chan context.Context) runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) }) - go c.writeLoop() + go c.timeoutLoop() go c.readLoop() } // We never mask inside here because our mask key is always 0,0,0,0. // See comment on secWebSocketKey. -func (c *Conn) writeFrame(h header, p []byte) { +func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error { + err := c.acquireLock(ctx, c.writeFrameLock) + if err != nil { + return err + } + defer c.releaseLock(c.writeFrameLock) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closed: + return c.closeErr + case c.setWriteTimeout <- ctx: + } + defer func() { + // We have to remove the write timeout, even if ctx is cancelled. + select { + case <-c.closed: + return + case c.setWriteTimeout <- context.Background(): + } + }() + + h.masked = c.client + h.payloadLength = int64(len(p)) + b2 := marshalHeader(h) - _, err := c.bw.Write(b2) + _, err = c.bw.Write(b2) if err != nil { c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return + return c.closeErr } - _, err = c.bw.Write(p) if err != nil { c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return + return c.closeErr + } if h.fin { err := c.bw.Flush() if err != nil { c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return + return c.closeErr } } -} -func (c *Conn) writeLoopFastWrite(frame frame) { - h := header{ - fin: true, - opcode: frame.opcode, - payloadLength: int64(len(frame.payload)), - masked: c.client, - } - c.writeFrame(h, frame.payload) - select { - case <-c.closed: - case c.writeDone <- struct{}{}: - } + return nil } -func (c *Conn) writeLoop() { - defer close(c.writeDone) +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + parentCtx := context.Background() + cancelCtx := func() {} + defer func() { + // We do not defer cancelCtx because its value can change. + cancelCtx() + }() -messageLoop: for { - var dataType MessageType select { case <-c.closed: return - case control := <-c.control: - c.writeLoopFastWrite(control) - continue - case frame := <-c.fastWrite: - c.writeLoopFastWrite(frame) - continue - case dataType = <-c.write: - } - - var firstSent bool - for { + case readCtx = <-c.setWriteTimeout: + case writeCtx = <-c.setReadTimeout: + case <-readCtx.Done(): + c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err())) + case <-writeCtx.Done(): + c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err())) + case <-parentCtx.Done(): + c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err())) + return + case parentCtx = <-c.setConnContext: + var ctx context.Context + ctx, cancelCtx = context.WithCancel(parentCtx) select { case <-c.closed: return - case control := <-c.control: - c.writeLoopFastWrite(control) - case b := <-c.writeBytes: - h := header{ - fin: false, - opcode: opcode(dataType), - payloadLength: int64(len(b)), - masked: c.client, - } - - if firstSent { - h.opcode = opContinuation - } - firstSent = true - - c.writeFrame(h, b) - - select { - case <-c.closed: - return - case c.writeDone <- struct{}{}: - } - case <-c.writeFlush: - h := header{ - fin: true, - opcode: opcode(dataType), - payloadLength: 0, - masked: c.client, - } - - if firstSent { - h.opcode = opContinuation - } - - c.writeFrame(h, nil) - - continue messageLoop + case <-parentCtx.Done(): + c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err())) + return + case c.getConnContext <- ctx: } } } @@ -250,19 +240,20 @@ func (c *Conn) handleControl(h header) { } } -func (c *Conn) readLoop() { - defer close(c.readDone) - +func (c *Conn) readTillData() (header, error) { for { h, err := readHeader(c.br) if err != nil { - c.close(xerrors.Errorf("failed to read header: %w", err)) - return + return header{}, xerrors.Errorf("failed to read header: %w", err) } if h.rsv1 || h.rsv2 || h.rsv3 { - c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) - return + ce := CloseError{ + Code: StatusProtocolError, + Reason: fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3), + } + c.Close(ce.Code, ce.Reason) + return header{}, ce } if h.opcode.controlOp() { @@ -271,84 +262,46 @@ func (c *Conn) readLoop() { } switch h.opcode { - case opBinary, opText: - if c.inMsg { - c.Close(StatusProtocolError, "cannot read new data frame when previous frame is not finished") - return - } - - select { - case <-c.closed: - return - case c.read <- h.opcode: - c.inMsg = true - } - case opContinuation: - if !c.inMsg { - c.Close(StatusProtocolError, "continuation frame not after data or text frame") - return - } + case opBinary, opText, opContinuation: + return h, nil default: - c.Close(StatusProtocolError, fmt.Sprintf("unknown opcode %v", h.opcode)) - return + ce := CloseError{ + Code: StatusProtocolError, + Reason: fmt.Sprintf("unknown opcode %v", h.opcode), + } + c.Close(ce.Code, ce.Reason) + return header{}, ce } + } +} - err = c.dataReadLoop(h) +func (c *Conn) readLoop() { + for { + h, err := c.readTillData() if err != nil { - c.close(xerrors.Errorf("failed to read from connection: %w", err)) + c.close(err) return } - } -} -func (c *Conn) dataReadLoop(h header) error { - maskPos := 0 - left := h.payloadLength - firstReadDone := false - for left > 0 || !firstReadDone { select { case <-c.closed: - return c.closeErr - case b := <-c.readBytes: - if int64(len(b)) > left { - b = b[:left] - } - - _, err := io.ReadFull(c.br, b) - if err != nil { - return xerrors.Errorf("failed to read from connection: %w", err) - } - left -= int64(len(b)) - - if h.masked { - maskPos = fastXOR(h.maskKey, maskPos, b) - } - - // Must set this before we signal the read is done. - // The reader will use this to return io.EOF and - // c.Read will use it to check if the reader has been completed. - if left == 0 && h.fin { - atomic.StoreInt64(&c.activeReader, 0) - c.inMsg = false - } + return + case c.readData <- h: + } - select { - case <-c.closed: - return c.closeErr - case c.readDone <- len(b): - firstReadDone = true - } + select { + case <-c.closed: + return + case <-c.readDone: } } - - return nil } func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeSingleFrame(ctx, opPong, p) + err := c.writeCompleteMessage(ctx, opPong, p) return err } @@ -393,7 +346,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeSingleFrame(ctx, opClose, p) + err := c.writeCompleteMessage(ctx, opClose, p) c.close(cerr) @@ -408,33 +361,40 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { return nil } -func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) error { - ch := c.fastWrite - if opcode.controlOp() { - ch = c.control - } +func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { select { - case <-c.closed: - return c.closeErr - case ch <- frame{ - opcode: opcode, - payload: p, - }: case <-ctx.Done(): - c.close(xerrors.Errorf("control frame write timed out: %w", ctx.Err())) return ctx.Err() - } - - select { case <-c.closed: return c.closeErr - case <-c.writeDone: + case lock <- struct{}{}: return nil - case <-ctx.Done(): - return ctx.Err() } } +func (c *Conn) releaseLock(lock chan struct{}) { + <-lock +} + +func (c *Conn) writeCompleteMessage(ctx context.Context, opcode opcode, p []byte) error { + if !opcode.controlOp() { + err := c.acquireLock(ctx, c.writeDataLock) + if err != nil { + return err + } + defer c.releaseLock(c.writeDataLock) + } + + err := c.writeFrame(ctx, header{ + fin: true, + opcode: opcode, + }, p) + if err != nil { + return xerrors.Errorf("failed to write frame: %v", err) + } + return nil +} + // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // @@ -451,27 +411,27 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - select { - case <-c.closed: - return nil, c.closeErr - case <-ctx.Done(): - return nil, ctx.Err() - case c.write <- typ: - return messageWriter{ - ctx: ctx, - c: c, - }, nil + err := c.acquireLock(ctx, c.writeDataLock) + if err != nil { + return nil, err } + return &messageWriter{ + ctx: ctx, + opcode: opcode(typ), + c: c, + }, nil } // messageWriter enables writing to a WebSocket connection. type messageWriter struct { - ctx context.Context - c *Conn + ctx context.Context + opcode opcode + c *Conn + closed bool } // Write writes the given bytes to the WebSocket connection. -func (w messageWriter) Write(p []byte) (int, error) { +func (w *messageWriter) Write(p []byte) (int, error) { n, err := w.write(p) if err != nil { return n, xerrors.Errorf("failed to write: %w", err) @@ -479,31 +439,23 @@ func (w messageWriter) Write(p []byte) (int, error) { return n, nil } -func (w messageWriter) write(p []byte) (int, error) { - select { - case <-w.c.closed: - return 0, w.c.closeErr - case w.c.writeBytes <- p: - select { - case <-w.ctx.Done(): - w.c.close(xerrors.Errorf("data write timed out: %w", w.ctx.Err())) - // Wait for writeLoop to complete so we know p is done with. - <-w.c.writeDone - return 0, w.ctx.Err() - case _, ok := <-w.c.writeDone: - if !ok { - return 0, w.c.closeErr - } - return len(p), nil - } - case <-w.ctx.Done(): - return 0, w.ctx.Err() +func (w *messageWriter) write(p []byte) (int, error) { + if w.closed { + return 0, xerrors.Errorf("cannot use closed writer") + } + err := w.c.writeFrame(w.ctx, header{ + opcode: w.opcode, + }, p) + if err != nil { + return 0, err } + w.opcode = opContinuation + return len(p), nil } // Close flushes the frame to the connection. // This must be called for every messageWriter. -func (w messageWriter) Close() error { +func (w *messageWriter) Close() error { err := w.close() if err != nil { return xerrors.Errorf("failed to close writer: %w", err) @@ -511,15 +463,22 @@ func (w messageWriter) Close() error { return nil } -func (w messageWriter) close() error { - select { - case <-w.c.closed: - return w.c.closeErr - case <-w.ctx.Done(): - return w.ctx.Err() - case w.c.writeFlush <- struct{}{}: - return nil +func (w *messageWriter) close() error { + if w.closed { + return xerrors.Errorf("cannot use closed writer") } + w.closed = true + + err := w.c.writeFrame(w.ctx, header{ + fin: true, + opcode: w.opcode, + }, nil) + if err != nil { + return err + } + + w.c.releaseLock(w.c.writeDataLock) + return nil } // Reader will wait until there is a WebSocket data message to read from the connection. @@ -542,49 +501,70 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { } func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { - // If the next read yields io.EOF we are good to go. - r := messageReader{ - ctx: ctx, - c: c, - } - _, err := r.Read(nil) - if err == nil { - return 0, nil, xerrors.New("previous message not fully read") - } - if !xerrors.Is(err, io.EOF) { - return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err) - } + // if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { + // // If the next read yields io.EOF we are good to go. + // r := messageReader{ + // ctx: ctx, + // c: c, + // } + // _, err := r.Read(nil) + // if err == nil { + // return 0, nil, xerrors.New("previous message not fully read") + // } + // if !xerrors.Is(err, io.EOF) { + // return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err) + // } + // + // atomic.StoreInt64(&c.activeReader, 1) + // } - atomic.StoreInt64(&c.activeReader, 1) + select { + case <-c.closed: + return 0, nil, c.closeErr + case <-ctx.Done(): + return 0, nil, ctx.Err() + case c.setReadTimeout <- ctx: } select { case <-c.closed: return 0, nil, c.closeErr - case opcode := <-c.read: - return MessageType(opcode), messageReader{ - ctx: ctx, - c: c, - }, nil case <-ctx.Done(): return 0, nil, ctx.Err() + case h := <-c.readData: + if h.opcode == opContinuation { + if h.fin && h.payloadLength == 0 { + select { + case <-c.closed: + return 0, nil, c.closeErr + case c.readDone <- struct{}{}: + return c.reader(ctx) + } + } + return 0, nil, xerrors.Errorf("previous reader was not read to EOF") + } + return MessageType(h.opcode), &messageReader{ + h: &h, + c: c, + }, nil } } // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { - ctx context.Context - c *Conn + maskPos int + h *header + c *Conn + eofed bool } // Read reads as many bytes as possible into p. -func (r messageReader) Read(p []byte) (int, error) { +func (r *messageReader) Read(p []byte) (int, error) { n, err := r.read(p) if err != nil { // Have to return io.EOF directly for now, we cannot wrap as xerrors // isn't used in stdlib. - if err == io.EOF { + if xerrors.Is(err, io.EOF) { return n, io.EOF } return n, xerrors.Errorf("failed to read: %w", err) @@ -592,31 +572,62 @@ func (r messageReader) Read(p []byte) (int, error) { return n, nil } -func (r messageReader) read(p []byte) (_ int, err error) { - if atomic.LoadInt64(&r.c.activeReader) == 0 { - return 0, io.EOF +func (r *messageReader) read(p []byte) (int, error) { + if r.eofed { + return 0, xerrors.Errorf("cannot use EOFed reader") } - select { - case <-r.c.closed: - return 0, r.c.closeErr - case r.c.readBytes <- p: + if r.h == nil { select { - case <-r.ctx.Done(): - r.c.close(xerrors.Errorf("data read timed out: %w", r.ctx.Err())) - // Wait for readLoop to complete so we know p is done. - <-r.c.readDone - return 0, r.ctx.Err() - case n, ok := <-r.c.readDone: - if !ok { - return 0, r.c.closeErr + case <-r.c.closed: + return 0, r.c.closeErr + case h := <-r.c.readData: + if h.opcode != opContinuation { + ce := CloseError{ + Code: StatusProtocolError, + Reason: "cannot read new data frame when previous frame is not finished", + } + r.c.Close(ce.Code, ce.Reason) + return 0, ce } - if atomic.LoadInt64(&r.c.activeReader) == 0 { + r.h = &h + } + } + + if int64(len(p)) > r.h.payloadLength { + p = p[:r.h.payloadLength] + } + + n, err := io.ReadFull(r.c.br, p) + + r.h.payloadLength -= int64(n) + if r.h.masked { + r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + } + + if err != nil { + r.c.close(xerrors.Errorf("failed to read control frame payload: %w", err)) + return n, r.c.closeErr + } + + if r.h.payloadLength == 0 { + select { + case <-r.c.closed: + return n, r.c.closeErr + case r.c.readDone <- struct{}{}: + } + if r.h.fin { + r.eofed = true + select { + case <-r.c.closed: + return n, r.c.closeErr + case r.c.setReadTimeout <- context.Background(): return n, io.EOF } - return n, nil } - case <-r.ctx.Done(): - return 0, r.ctx.Err() + r.maskPos = 0 + r.h = nil } + + return n, nil }