good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 648808d6 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Solve all remaining TODOs in an elegant fashion

parent f685c8d7
No related branches found
No related tags found
No related merge requests found
coverage.html coverage.html
wstest_reports wstest_reports
websocket.test
...@@ -4,10 +4,12 @@ import ( ...@@ -4,10 +4,12 @@ import (
"context" "context"
"io" "io"
"net/http" "net/http"
"nhooyr.io/websocket" "strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
"nhooyr.io/websocket"
) )
func BenchmarkConn(b *testing.B) { func BenchmarkConn(b *testing.B) {
...@@ -36,17 +38,19 @@ func BenchmarkConn(b *testing.B) { ...@@ -36,17 +38,19 @@ func BenchmarkConn(b *testing.B) {
} }
defer c.Close(websocket.StatusInternalError, "") defer c.Close(websocket.StatusInternalError, "")
msg := strings.Repeat("2", 4096*16) runN := func(n int) {
b.Run(strconv.Itoa(n), func(b *testing.B) {
msg := []byte(strings.Repeat("2", n))
buf := make([]byte, len(msg)) buf := make([]byte, len(msg))
b.SetBytes(int64(len(msg))) b.SetBytes(int64(len(msg)))
b.StartTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
w, err := c.Write(ctx, websocket.MessageText) w, err := c.Write(ctx, websocket.MessageText)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
_, err = io.WriteString(w, msg) _, err = w.Write(msg)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
...@@ -65,13 +69,19 @@ func BenchmarkConn(b *testing.B) { ...@@ -65,13 +69,19 @@ func BenchmarkConn(b *testing.B) {
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
// TODO jank
_, err = r.Read(nil)
if err != io.EOF {
b.Fatalf("wtf %q", err)
}
} }
b.StopTimer() b.StopTimer()
})
}
runN(32)
runN(128)
runN(512)
runN(1024)
runN(4096)
runN(16384)
runN(65536)
runN(131072)
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
} }
...@@ -7,6 +7,8 @@ import ( ...@@ -7,6 +7,8 @@ import (
) )
func Test_verifyServerHandshake(t *testing.T) { func Test_verifyServerHandshake(t *testing.T) {
t.Parallel()
testCases := []struct { testCases := []struct {
name string name string
response func(w http.ResponseWriter) response func(w http.ResponseWriter)
......
...@@ -79,7 +79,7 @@ func ExampleAccept() { ...@@ -79,7 +79,7 @@ func ExampleAccept() {
log.Printf("server handshake failed: %v", err) log.Printf("server handshake failed: %v", err)
return return
} }
defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error. defer c.Close(websocket.StatusInternalError, "")
jc := websocket.JSONConn{ jc := websocket.JSONConn{
Conn: c, Conn: c,
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"math"
"golang.org/x/xerrors" "golang.org/x/xerrors"
) )
...@@ -55,7 +56,7 @@ func marshalHeader(h header) []byte { ...@@ -55,7 +56,7 @@ func marshalHeader(h header) []byte {
panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength))
case h.payloadLength <= 125: case h.payloadLength <= 125:
b[1] = byte(h.payloadLength) b[1] = byte(h.payloadLength)
case h.payloadLength <= 1<<16: case h.payloadLength <= math.MaxUint16:
b[1] = 126 b[1] = 126
b = b[:len(b)+2] b = b[:len(b)+2]
binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength))
...@@ -105,10 +106,8 @@ func readHeader(r io.Reader) (header, error) { ...@@ -105,10 +106,8 @@ func readHeader(r io.Reader) (header, error) {
case payloadLength < 126: case payloadLength < 126:
h.payloadLength = int64(payloadLength) h.payloadLength = int64(payloadLength)
case payloadLength == 126: case payloadLength == 126:
h.payloadLength = 126
extra += 2 extra += 2
case payloadLength == 127: case payloadLength == 127:
h.payloadLength = 127
extra += 8 extra += 8
} }
......
...@@ -3,6 +3,7 @@ package websocket ...@@ -3,6 +3,7 @@ package websocket
import ( import (
"bytes" "bytes"
"math/rand" "math/rand"
"strconv"
"testing" "testing"
"time" "time"
...@@ -36,10 +37,38 @@ func TestHeader(t *testing.T) { ...@@ -36,10 +37,38 @@ func TestHeader(t *testing.T) {
t.Fatalf("unexpected error value: %+v", err) t.Fatalf("unexpected error value: %+v", err)
} }
}) })
t.Run("lengths", func(t *testing.T) {
t.Parallel()
lengths := []int{
124,
125,
126,
4096,
16384,
65535,
65536,
65537,
131072,
}
for _, n := range lengths {
n := n
t.Run(strconv.Itoa(n), func(t *testing.T) {
t.Parallel()
testHeader(t, header{
payloadLength: int64(n),
})
})
}
})
t.Run("fuzz", func(t *testing.T) { t.Run("fuzz", func(t *testing.T) {
t.Parallel() t.Parallel()
for i := 0; i < 1000; i++ { for i := 0; i < 10000; i++ {
h := header{ h := header{
fin: randBool(), fin: randBool(),
rsv1: randBool(), rsv1: randBool(),
...@@ -55,6 +84,12 @@ func TestHeader(t *testing.T) { ...@@ -55,6 +84,12 @@ func TestHeader(t *testing.T) {
rand.Read(h.maskKey[:]) rand.Read(h.maskKey[:])
} }
testHeader(t, h)
}
})
}
func testHeader(t *testing.T, h header) {
b := marshalHeader(h) b := marshalHeader(h)
r := bytes.NewReader(b) r := bytes.NewReader(b)
h2, err := readHeader(r) h2, err := readHeader(r)
...@@ -70,5 +105,3 @@ func TestHeader(t *testing.T) { ...@@ -70,5 +105,3 @@ func TestHeader(t *testing.T) {
t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
} }
} }
})
}
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"log"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
...@@ -38,19 +37,20 @@ type Conn struct { ...@@ -38,19 +37,20 @@ type Conn struct {
// on writeBytes. // on writeBytes.
// Send on control to write a control message. // Send on control to write a control message.
// writeDone will be sent back when the message is written // writeDone will be sent back when the message is written
// Close writeBytes to flush the message and wait for a // Send on writeFlush to flush the message and wait for a
// ping on writeDone. // TODO should I care about this allocation? // ping on writeDone.
// writeDone will be closed if the data message write errors. // writeDone will be closed if the data message write errors.
write chan MessageType write chan MessageType
control chan control control chan control
writeBytes chan []byte writeBytes chan []byte
writeDone chan struct{} writeDone chan struct{}
writeFlush chan struct{}
// Readers should receive on read to begin reading a message. // Readers should receive on read to begin reading a message.
// Then send a byte slice to readBytes to read into it. // Then send a byte slice to readBytes to read into it.
// The n of bytes read will be sent on readDone once the read into a slice is complete. // The n of bytes read will be sent on readDone once the read into a slice is complete.
// readDone will be closed if the read fails. // readDone will be closed if the read fails.
// readInProgress will be set to 0 on io.EOF. // activeReader will be set to 0 on io.EOF.
activeReader int64 activeReader int64
inMsg bool inMsg bool
read chan opcode read chan opcode
...@@ -86,7 +86,9 @@ func (c *Conn) init() { ...@@ -86,7 +86,9 @@ func (c *Conn) init() {
c.write = make(chan MessageType) c.write = make(chan MessageType)
c.control = make(chan control) c.control = make(chan control)
c.writeBytes = make(chan []byte)
c.writeDone = make(chan struct{}) c.writeDone = make(chan struct{})
c.writeFlush = make(chan struct{})
c.read = make(chan opcode) c.read = make(chan opcode)
c.readBytes = make(chan []byte) c.readBytes = make(chan []byte)
...@@ -128,8 +130,6 @@ func (c *Conn) writeLoop() { ...@@ -128,8 +130,6 @@ func (c *Conn) writeLoop() {
messageLoop: messageLoop:
for { for {
c.writeBytes = make(chan []byte)
var dataType MessageType var dataType MessageType
select { select {
case <-c.closed: case <-c.closed:
...@@ -170,9 +170,9 @@ messageLoop: ...@@ -170,9 +170,9 @@ messageLoop:
case c.writeDone <- struct{}{}: case c.writeDone <- struct{}{}:
continue continue
} }
case b, ok := <-c.writeBytes: case b := <-c.writeBytes:
h := header{ h := header{
fin: !ok, fin: false,
opcode: opcode(dataType), opcode: opcode(dataType),
payloadLength: int64(len(b)), payloadLength: int64(len(b)),
masked: c.client, masked: c.client,
...@@ -183,30 +183,41 @@ messageLoop: ...@@ -183,30 +183,41 @@ messageLoop:
} }
firstSent = true firstSent = true
if c.client {
log.Printf("client %#v", h)
}
c.writeFrame(h, b) c.writeFrame(h, b)
if !ok { select {
err := c.bw.Flush() case <-c.closed:
if err != nil {
c.close(xerrors.Errorf("failed to write to connection: %w", err))
return return
case c.writeDone <- struct{}{}:
continue
} }
case <-c.writeFlush:
h := header{
fin: true,
opcode: opcode(dataType),
payloadLength: 0,
masked: c.client,
} }
if firstSent {
h.opcode = opContinuation
}
c.writeFrame(h, nil)
select { select {
case <-c.closed: case <-c.closed:
return return
case c.writeDone <- struct{}{}: case c.writeDone <- struct{}{}:
if ok {
continue
} else {
continue messageLoop
} }
err := c.bw.Flush()
if err != nil {
c.close(xerrors.Errorf("failed to write to connection: %w", err))
return
} }
continue messageLoop
} }
} }
} }
...@@ -264,10 +275,6 @@ func (c *Conn) readLoop() { ...@@ -264,10 +275,6 @@ func (c *Conn) readLoop() {
return return
} }
if !c.client {
log.Printf("%#v", h)
}
if h.rsv1 || h.rsv2 || h.rsv3 { if h.rsv1 || h.rsv2 || h.rsv3 {
c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
return return
...@@ -480,7 +487,14 @@ func (w messageWriter) Write(p []byte) (int, error) { ...@@ -480,7 +487,14 @@ func (w messageWriter) Write(p []byte) (int, error) {
// Close flushes the frame to the connection. // Close flushes the frame to the connection.
// This must be called for every messageWriter. // This must be called for every messageWriter.
func (w messageWriter) Close() error { func (w messageWriter) Close() error {
close(w.c.writeBytes) select {
case <-w.c.closed:
return w.c.closeErr
case <-w.ctx.Done():
return w.ctx.Err()
case w.c.writeFlush <- struct{}{}:
}
select { select {
case <-w.c.closed: case <-w.c.closed:
return w.c.closeErr return w.c.closeErr
...@@ -499,9 +513,26 @@ func (w messageWriter) Close() error { ...@@ -499,9 +513,26 @@ func (w messageWriter) Close() error {
// Please ensure to read the full message from io.Reader. // Please ensure to read the full message from io.Reader.
// You can only read a single message at a time. // You can only read a single message at a time.
func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) { func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) {
if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
select {
case <-c.closed:
return 0, nil, c.closeErr
case c.readBytes <- nil:
select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case _, ok := <-c.readDone:
if !ok {
return 0, nil, c.closeErr
}
if atomic.LoadInt64(&c.activeReader) == 1 {
return 0, nil, xerrors.New("websocket: previous message not fully read") return 0, nil, xerrors.New("websocket: previous message not fully read")
} }
}
case <-ctx.Done():
return 0, nil, ctx.Err()
}
}
select { select {
case <-c.closed: case <-c.closed:
...@@ -530,7 +561,7 @@ func (r messageReader) Read(p []byte) (int, error) { ...@@ -530,7 +561,7 @@ func (r messageReader) Read(p []byte) (int, error) {
if err == io.EOF { if err == io.EOF {
return n, io.EOF return n, io.EOF
} }
return n, xerrors.Errorf("failed to read: %w", err) return n, xerrors.Errorf("websocket: failed to read: %w", err)
} }
return n, nil return n, nil
} }
...@@ -546,7 +577,7 @@ func (r messageReader) read(p []byte) (_ int, err error) { ...@@ -546,7 +577,7 @@ func (r messageReader) read(p []byte) (_ int, err error) {
case r.c.readBytes <- p: case r.c.readBytes <- p:
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
r.c.close(xerrors.Errorf("read timed out: %w", err)) r.c.close(xerrors.Errorf("read timed out: %w", r.ctx.Err()))
// Wait for readloop to complete so we know p is done. // Wait for readloop to complete so we know p is done.
<-r.c.readDone <-r.c.readDone
return 0, r.ctx.Err() return 0, r.ctx.Err()
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/http/httptest" "net/http/httptest"
...@@ -197,6 +196,7 @@ func TestHandshake(t *testing.T) { ...@@ -197,6 +196,7 @@ func TestHandshake(t *testing.T) {
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
defer cancel() defer cancel()
write := func() error {
jc := websocket.JSONConn{ jc := websocket.JSONConn{
Conn: c, Conn: c,
} }
...@@ -208,6 +208,16 @@ func TestHandshake(t *testing.T) { ...@@ -208,6 +208,16 @@ func TestHandshake(t *testing.T) {
if err != nil { if err != nil {
return err return err
} }
return nil
}
err = write()
if err != nil {
return err
}
err = write()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
return nil return nil
...@@ -223,6 +233,7 @@ func TestHandshake(t *testing.T) { ...@@ -223,6 +233,7 @@ func TestHandshake(t *testing.T) {
Conn: c, Conn: c,
} }
read := func() error {
var v interface{} var v interface{}
err = jc.Read(ctx, &v) err = jc.Read(ctx, &v)
if err != nil { if err != nil {
...@@ -235,6 +246,17 @@ func TestHandshake(t *testing.T) { ...@@ -235,6 +246,17 @@ func TestHandshake(t *testing.T) {
if !reflect.DeepEqual(exp, v) { if !reflect.DeepEqual(exp, v) {
return xerrors.Errorf("expected %v but got %v", exp, v) return xerrors.Errorf("expected %v but got %v", exp, v)
} }
return nil
}
err = read()
if err != nil {
return err
}
// Read twice to ensure the un EOFed previous reader works correctly.
err = read()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
return nil return nil
...@@ -399,10 +421,11 @@ func TestAutobahnServer(t *testing.T) { ...@@ -399,10 +421,11 @@ func TestAutobahnServer(t *testing.T) {
func echoLoop(ctx context.Context, c *websocket.Conn) { func echoLoop(ctx context.Context, c *websocket.Conn) {
defer c.Close(websocket.StatusInternalError, "") defer c.Close(websocket.StatusInternalError, "")
echo := func() error {
ctx, cancel := context.WithTimeout(ctx, time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() defer cancel()
b := make([]byte, 32768)
echo := func() error {
typ, r, err := c.Read(ctx) typ, r, err := c.Read(ctx)
if err != nil { if err != nil {
return err return err
...@@ -413,10 +436,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { ...@@ -413,10 +436,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn) {
return err return err
} }
b1, _ := ioutil.ReadAll(r) _, err = io.CopyBuffer(w, r, b)
log.Printf("%q", b1)
_, err = io.Copy(w, r)
if err != nil { if err != nil {
return err return err
} }
...@@ -429,14 +449,11 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { ...@@ -429,14 +449,11 @@ func echoLoop(ctx context.Context, c *websocket.Conn) {
return nil return nil
} }
var i int
for { for {
err := echo() err := echo()
if err != nil { if err != nil {
log.Println("WTF", err, i)
return return
} }
i++
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment