diff --git a/.gitignore b/.gitignore index 70d8e7030c7c59458ca2741bea53c61e7ff22715..35ecb6b04d5cfb5e5df8b468368ccb5ac941e294 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ coverage.html wstest_reports +websocket.test diff --git a/bench_test.go b/bench_test.go index f5b5b2195a41b5ad5b9fe1ba4d9b768e74ca1189..66331e0c72e186950ec727e333e3dcb99aca8da5 100644 --- a/bench_test.go +++ b/bench_test.go @@ -4,10 +4,12 @@ import ( "context" "io" "net/http" - "nhooyr.io/websocket" + "strconv" "strings" "testing" "time" + + "nhooyr.io/websocket" ) func BenchmarkConn(b *testing.B) { @@ -36,42 +38,50 @@ func BenchmarkConn(b *testing.B) { } defer c.Close(websocket.StatusInternalError, "") - msg := strings.Repeat("2", 4096*16) - buf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.StartTimer() - for i := 0; i < b.N; i++ { - w, err := c.Write(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } + runN := func(n int) { + b.Run(strconv.Itoa(n), func(b *testing.B) { + msg := []byte(strings.Repeat("2", n)) + buf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + w, err := c.Write(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } - _, err = io.WriteString(w, msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) + } - _, r, err := c.Read(ctx) - if err != nil { - b.Fatal(err, b.N) - } + err = w.Close() + if err != nil { + b.Fatal(err) + } - _, err = io.ReadFull(r, buf) - if err != nil { - b.Fatal(err) - } + _, r, err := c.Read(ctx) + if err != nil { + b.Fatal(err, b.N) + } - // TODO jank - _, err = r.Read(nil) - if err != io.EOF { - b.Fatalf("wtf %q", err) - } + _, err = io.ReadFull(r, buf) + if err != nil { + b.Fatal(err) + } + } + b.StopTimer() + }) } - b.StopTimer() + + runN(32) + runN(128) + runN(512) + runN(1024) + runN(4096) + runN(16384) + runN(65536) + runN(131072) + c.Close(websocket.StatusNormalClosure, "") } diff --git a/dial_test.go b/dial_test.go index 48c1c3125a4b33b5eaf4eb7db82e447d29e08c54..02aaa4fc874df6ed826027cfa9e26a52b82d9f2a 100644 --- a/dial_test.go +++ b/dial_test.go @@ -7,6 +7,8 @@ import ( ) func Test_verifyServerHandshake(t *testing.T) { + t.Parallel() + testCases := []struct { name string response func(w http.ResponseWriter) diff --git a/example_test.go b/example_test.go index 85cd3aa123888599b176cd9bbae44b98f649001e..c343d78f3b86c956937e9506600a119668665113 100644 --- a/example_test.go +++ b/example_test.go @@ -79,7 +79,7 @@ func ExampleAccept() { log.Printf("server handshake failed: %v", err) return } - defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error. + defer c.Close(websocket.StatusInternalError, "") jc := websocket.JSONConn{ Conn: c, diff --git a/header.go b/header.go index 276fa0c30b93f6120c18d064c96b1c5e05548f1d..82ad5f561431e322fcdd402f9c3c6cfade64f49d 100644 --- a/header.go +++ b/header.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "math" "golang.org/x/xerrors" ) @@ -55,7 +56,7 @@ func marshalHeader(h header) []byte { panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) case h.payloadLength <= 125: b[1] = byte(h.payloadLength) - case h.payloadLength <= 1<<16: + case h.payloadLength <= math.MaxUint16: b[1] = 126 b = b[:len(b)+2] binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) @@ -105,10 +106,8 @@ func readHeader(r io.Reader) (header, error) { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: - h.payloadLength = 126 extra += 2 case payloadLength == 127: - h.payloadLength = 127 extra += 8 } diff --git a/header_test.go b/header_test.go index b4d0769fb2c33044fd841c04041f58a8d8328028..b9cf351b93255e09c1036be4357af0d0c6f94ef1 100644 --- a/header_test.go +++ b/header_test.go @@ -3,6 +3,7 @@ package websocket import ( "bytes" "math/rand" + "strconv" "testing" "time" @@ -36,10 +37,38 @@ func TestHeader(t *testing.T) { t.Fatalf("unexpected error value: %+v", err) } }) + + t.Run("lengths", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 124, + 125, + 126, + 4096, + 16384, + 65535, + 65536, + 65537, + 131072, + } + + for _, n := range lengths { + n := n + t.Run(strconv.Itoa(n), func(t *testing.T) { + t.Parallel() + + testHeader(t, header{ + payloadLength: int64(n), + }) + }) + } + }) + t.Run("fuzz", func(t *testing.T) { t.Parallel() - for i := 0; i < 1000; i++ { + for i := 0; i < 10000; i++ { h := header{ fin: randBool(), rsv1: randBool(), @@ -55,20 +84,24 @@ func TestHeader(t *testing.T) { rand.Read(h.maskKey[:]) } - b := marshalHeader(h) - r := bytes.NewReader(b) - h2, err := readHeader(r) - if err != nil { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read header: %v", err) - } - - if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) - } + testHeader(t, h) } }) } + +func testHeader(t *testing.T, h header) { + b := marshalHeader(h) + r := bytes.NewReader(b) + h2, err := readHeader(r) + if err != nil { + t.Logf("header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("failed to read header: %v", err) + } + + if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { + t.Logf("header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) + } +} diff --git a/websocket.go b/websocket.go index 52f42dc80920a3b9bf11f6f1d8ec1d5244350b5b..79923518038ae6b7fd0052d29e203cbaac4bcc7e 100644 --- a/websocket.go +++ b/websocket.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "log" "runtime" "sync" "sync/atomic" @@ -38,19 +37,20 @@ type Conn struct { // on writeBytes. // Send on control to write a control message. // writeDone will be sent back when the message is written - // Close writeBytes to flush the message and wait for a - // ping on writeDone. // TODO should I care about this allocation? + // Send on writeFlush to flush the message and wait for a + // ping on writeDone. // writeDone will be closed if the data message write errors. write chan MessageType control chan control writeBytes chan []byte writeDone chan struct{} + writeFlush chan struct{} // Readers should receive on read to begin reading a message. // Then send a byte slice to readBytes to read into it. // The n of bytes read will be sent on readDone once the read into a slice is complete. // readDone will be closed if the read fails. - // readInProgress will be set to 0 on io.EOF. + // activeReader will be set to 0 on io.EOF. activeReader int64 inMsg bool read chan opcode @@ -86,7 +86,9 @@ func (c *Conn) init() { c.write = make(chan MessageType) c.control = make(chan control) + c.writeBytes = make(chan []byte) c.writeDone = make(chan struct{}) + c.writeFlush = make(chan struct{}) c.read = make(chan opcode) c.readBytes = make(chan []byte) @@ -128,8 +130,6 @@ func (c *Conn) writeLoop() { messageLoop: for { - c.writeBytes = make(chan []byte) - var dataType MessageType select { case <-c.closed: @@ -170,9 +170,9 @@ messageLoop: case c.writeDone <- struct{}{}: continue } - case b, ok := <-c.writeBytes: + case b := <-c.writeBytes: h := header{ - fin: !ok, + fin: false, opcode: opcode(dataType), payloadLength: int64(len(b)), masked: c.client, @@ -183,30 +183,41 @@ messageLoop: } firstSent = true - if c.client { - log.Printf("client %#v", h) - } - c.writeFrame(h, b) - if !ok { - err := c.bw.Flush() - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return - } + select { + case <-c.closed: + return + case c.writeDone <- struct{}{}: + continue + } + case <-c.writeFlush: + h := header{ + fin: true, + opcode: opcode(dataType), + payloadLength: 0, + masked: c.client, + } + + if firstSent { + h.opcode = opContinuation } + c.writeFrame(h, nil) + select { case <-c.closed: return case c.writeDone <- struct{}{}: - if ok { - continue - } else { - continue messageLoop - } } + + err := c.bw.Flush() + if err != nil { + c.close(xerrors.Errorf("failed to write to connection: %w", err)) + return + } + + continue messageLoop } } } @@ -264,10 +275,6 @@ func (c *Conn) readLoop() { return } - if !c.client { - log.Printf("%#v", h) - } - if h.rsv1 || h.rsv2 || h.rsv3 { c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) return @@ -480,7 +487,14 @@ func (w messageWriter) Write(p []byte) (int, error) { // Close flushes the frame to the connection. // This must be called for every messageWriter. func (w messageWriter) Close() error { - close(w.c.writeBytes) + select { + case <-w.c.closed: + return w.c.closeErr + case <-w.ctx.Done(): + return w.ctx.Err() + case w.c.writeFlush <- struct{}{}: + } + select { case <-w.c.closed: return w.c.closeErr @@ -499,8 +513,25 @@ func (w messageWriter) Close() error { // Please ensure to read the full message from io.Reader. // You can only read a single message at a time. func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) { - if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { - return 0, nil, xerrors.New("websocket: previous message not fully read") + for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { + select { + case <-c.closed: + return 0, nil, c.closeErr + case c.readBytes <- nil: + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case _, ok := <-c.readDone: + if !ok { + return 0, nil, c.closeErr + } + if atomic.LoadInt64(&c.activeReader) == 1 { + return 0, nil, xerrors.New("websocket: previous message not fully read") + } + } + case <-ctx.Done(): + return 0, nil, ctx.Err() + } } select { @@ -530,7 +561,7 @@ func (r messageReader) Read(p []byte) (int, error) { if err == io.EOF { return n, io.EOF } - return n, xerrors.Errorf("failed to read: %w", err) + return n, xerrors.Errorf("websocket: failed to read: %w", err) } return n, nil } @@ -546,7 +577,7 @@ func (r messageReader) read(p []byte) (_ int, err error) { case r.c.readBytes <- p: select { case <-r.ctx.Done(): - r.c.close(xerrors.Errorf("read timed out: %w", err)) + r.c.close(xerrors.Errorf("read timed out: %w", r.ctx.Err())) // Wait for readloop to complete so we know p is done. <-r.c.readDone return 0, r.ctx.Err() diff --git a/websocket_test.go b/websocket_test.go index dba0182e8ef23dc80ee1cbdbf76042fd34273b4c..d6d222d55e9aac2e88736975b36ee1ca7da7428b 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -197,14 +196,25 @@ func TestHandshake(t *testing.T) { ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) defer cancel() - jc := websocket.JSONConn{ - Conn: c, - } + write := func() error { + jc := websocket.JSONConn{ + Conn: c, + } - v := map[string]interface{}{ - "anmol": "wowow", + v := map[string]interface{}{ + "anmol": "wowow", + } + err = jc.Write(ctx, v) + if err != nil { + return err + } + return nil } - err = jc.Write(ctx, v) + err = write() + if err != nil { + return err + } + err = write() if err != nil { return err } @@ -223,17 +233,29 @@ func TestHandshake(t *testing.T) { Conn: c, } - var v interface{} - err = jc.Read(ctx, &v) + read := func() error { + var v interface{} + err = jc.Read(ctx, &v) + if err != nil { + return err + } + + exp := map[string]interface{}{ + "anmol": "wowow", + } + if !reflect.DeepEqual(exp, v) { + return xerrors.Errorf("expected %v but got %v", exp, v) + } + return nil + } + err = read() if err != nil { return err } - - exp := map[string]interface{}{ - "anmol": "wowow", - } - if !reflect.DeepEqual(exp, v) { - return xerrors.Errorf("expected %v but got %v", exp, v) + // Read twice to ensure the un EOFed previous reader works correctly. + err = read() + if err != nil { + return err } c.Close(websocket.StatusNormalClosure, "") @@ -399,10 +421,11 @@ func TestAutobahnServer(t *testing.T) { func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") - echo := func() error { - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + b := make([]byte, 32768) + echo := func() error { typ, r, err := c.Read(ctx) if err != nil { return err @@ -413,10 +436,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { return err } - b1, _ := ioutil.ReadAll(r) - log.Printf("%q", b1) - - _, err = io.Copy(w, r) + _, err = io.CopyBuffer(w, r, b) if err != nil { return err } @@ -429,14 +449,11 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { return nil } - var i int for { err := echo() if err != nil { - log.Println("WTF", err, i) return } - i++ } }