diff --git a/websocket.go b/conn.go similarity index 99% rename from websocket.go rename to conn.go index bbadb9bc619ebb66117d63ac4f1a8fdcd7d5c61a..bc115e38f4d8b328c3c6848c831cf3bdf93ebbb6 100644 --- a/websocket.go +++ b/conn.go @@ -138,7 +138,7 @@ func (c *Conn) close(err error) { // closeErr. c.closer.Close() - // See comment in dial.go + // See comment on bufioReaderPool in handshake.go if c.client { // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer // and we can safely return them. diff --git a/websocket_test.go b/conn_test.go similarity index 84% rename from websocket_test.go rename to conn_test.go index 2fabba545ebc9a03be2b75a64e622b2aa3504b82..8846979d5dc639836533f315f2ec90d005b412e8 100644 --- a/websocket_test.go +++ b/conn_test.go @@ -12,10 +12,13 @@ import ( "io" "io/ioutil" "math/rand" + "net" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" + "os" + "os/exec" "reflect" "strconv" "strings" @@ -1962,3 +1965,369 @@ func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.Mes } return assertEqualf(p, actP, "unexpected frame %v payload", actTyp) } + +func BenchmarkConn(b *testing.B) { + sizes := []int{ + 2, + 16, + 32, + 512, + 4096, + 16384, + } + + b.Run("write", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + b.Run("stream", func(b *testing.B) { + benchConn(b, false, true, size) + }) + b.Run("buffer", func(b *testing.B) { + benchConn(b, false, false, size) + }) + }) + } + }) + + b.Run("echo", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchConn(b, true, true, size) + }) + } + }) +} + +func benchConn(b *testing.B, echo, stream bool, size int) { + s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + if echo { + wsecho.Loop(r.Context(), c) + } else { + discardLoop(r.Context(), c) + } + return nil + }, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + b.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + msg := []byte(strings.Repeat("2", size)) + readBuf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if stream { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } + + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) + } + + err = w.Close() + if err != nil { + b.Fatal(err) + } + } else { + err = c.Write(ctx, websocket.MessageText, msg) + if err != nil { + b.Fatal(err) + } + } + + if echo { + _, r, err := c.Reader(ctx) + if err != nil { + b.Fatal(err) + } + + _, err = io.ReadFull(r, readBuf) + if err != nil { + b.Fatal(err) + } + } + } + b.StopTimer() + + c.Close(websocket.StatusNormalClosure, "") +} + +func discardLoop(ctx context.Context, c *websocket.Conn) { + defer c.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32768) + echo := func() error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + + _, err = io.CopyBuffer(ioutil.Discard, r, b) + if err != nil { + return err + } + return nil + } + + for { + err := echo() + if err != nil { + return + } + } +} + +func TestAutobahnPython(t *testing.T) { + // This test contains the old autobahn test suite tests that use the + // python binary. The approach is clunky and slow so new tests + // have been written in pure Go in websocket_test.go. + // These have been kept for correctness purposes and are occasionally ran. + if os.Getenv("AUTOBAHN_PYTHON") == "" { + t.Skip("Set $AUTOBAHN_PYTHON to run tests against the python autobahn test suite") + } + + t.Run("server", testServerAutobahnPython) + t.Run("client", testClientAutobahnPython) +} + +// https://github.com/crossbario/autobahn-python/tree/master/wstest +func testServerAutobahnPython(t *testing.T) { + t.Parallel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) + if err != nil { + t.Logf("server handshake failed: %+v", err) + return + } + wsecho.Loop(r.Context(), c) + })) + defer s.Close() + + spec := map[string]interface{}{ + "outdir": "ci/out/wstestServerReports", + "servers": []interface{}{ + map[string]interface{}{ + "agent": "main", + "url": strings.Replace(s.URL, "http", "ws", 1), + }, + }, + "cases": []string{"*"}, + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. 7.5.1 is the same. + // 12.* and 13.* as we do not support compression. + "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + wstest := exec.CommandContext(ctx, "wstest", args...) + out, err := wstest.CombinedOutput() + if err != nil { + t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + } + + checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") +} + +func unusedListenAddr() (string, error) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py +func testClientAutobahnPython(t *testing.T) { + t.Parallel() + + if os.Getenv("AUTOBAHN_PYTHON") == "" { + t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") + } + + serverAddr, err := unusedListenAddr() + if err != nil { + t.Fatalf("failed to get unused listen addr for wstest: %v", err) + } + + wsServerURL := "ws://" + serverAddr + + spec := map[string]interface{}{ + "url": wsServerURL, + "outdir": "ci/out/wstestClientReports", + "cases": []string{"*"}, + // See TestAutobahnServer for the reasons why we exclude these. + "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), + // Disables some server that runs as part of fuzzingserver mode. + // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 + "--webport=0", + } + wstest := exec.CommandContext(ctx, "wstest", args...) + err = wstest.Start() + if err != nil { + t.Fatal(err) + } + defer func() { + err := wstest.Process.Kill() + if err != nil { + t.Error(err) + } + }() + + // Let it come up. + time.Sleep(time.Second * 5) + + var cases int + func() { + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + _, r, err := c.Reader(ctx) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + cases, err = strconv.Atoi(string(b)) + if err != nil { + t.Fatal(err) + } + + c.Close(websocket.StatusNormalClosure, "") + }() + + for i := 1; i <= cases; i++ { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*45) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) + if err != nil { + t.Fatal(err) + } + wsecho.Loop(ctx, c) + }() + } + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) + if err != nil { + t.Fatal(err) + } + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func checkWSTestIndex(t *testing.T, path string) { + wstestOut, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("failed to read index.json: %v", err) + } + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + BehaviorClose string `json:"behaviorClose"` + } + err = json.Unmarshal(wstestOut, &indexJSON) + if err != nil { + t.Fatalf("failed to unmarshal index.json: %v", err) + } + + var failed bool + for _, tests := range indexJSON { + for test, result := range tests { + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + failed = true + t.Errorf("test %v failed", test) + } + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + failed = true + t.Errorf("bad close behaviour for test %v", test) + } + } + } + + if failed { + path = strings.Replace(path, ".json", ".html", 1) + if os.Getenv("CI") == "" { + t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) + } + } +} diff --git a/dial.go b/dial.go deleted file mode 100644 index 79232aac86c0bd4dcbeea2137a01898eaee3bd3e..0000000000000000000000000000000000000000 --- a/dial.go +++ /dev/null @@ -1,200 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "bytes" - "context" - "encoding/base64" - "fmt" - "io" - "io/ioutil" - "math/rand" - "net/http" - "net/url" - "strings" - "sync" -) - -// DialOptions represents the options available to pass to Dial. -type DialOptions struct { - // HTTPClient is the http client used for the handshake. - // Its Transport must return writable bodies - // for WebSocket handshakes. - // http.Transport does this correctly beginning with Go 1.12. - HTTPClient *http.Client - - // HTTPHeader specifies the HTTP headers included in the handshake request. - HTTPHeader http.Header - - // Subprotocols lists the subprotocols to negotiate with the server. - Subprotocols []string -} - -// Dial performs a WebSocket handshake on the given url with the given options. -// The response is the WebSocket handshake response from the server. -// If an error occurs, the returned response may be non nil. However, you can only -// read the first 1024 bytes of its body. -// -// You never need to close the resp.Body yourself. -// -// This function requires at least Go 1.12 to succeed as it uses a new feature -// in net/http to perform WebSocket handshakes and get a writable body -// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - c, r, err := dial(ctx, u, opts) - if err != nil { - return nil, r, fmt.Errorf("failed to websocket dial: %w", err) - } - return c, r, nil -} - -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - if opts == nil { - opts = &DialOptions{} - } - - // Shallow copy to ensure defaults do not affect user passed options. - opts2 := *opts - opts = &opts2 - - if opts.HTTPClient == nil { - opts.HTTPClient = http.DefaultClient - } - if opts.HTTPClient.Timeout > 0 { - return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - - parsedURL, err := url.Parse(u) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse url: %w", err) - } - - switch parsedURL.Scheme { - case "ws": - parsedURL.Scheme = "http" - case "wss": - parsedURL.Scheme = "https" - default: - return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) - } - - req, _ := http.NewRequest("GET", parsedURL.String(), nil) - req = req.WithContext(ctx) - req.Header = opts.HTTPHeader - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey()) - if len(opts.Subprotocols) > 0 { - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) - } - - resp, err := opts.HTTPClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) - } - defer func() { - if err != nil { - // We read a bit of the body for easier debugging. - r := io.LimitReader(resp.Body, 1024) - b, _ := ioutil.ReadAll(r) - resp.Body.Close() - resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - } - }() - - err = verifyServerResponse(req, resp) - if err != nil { - return nil, resp, err - } - - rwc, ok := resp.Body.(io.ReadWriteCloser) - if !ok { - return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) - } - - c := &Conn{ - subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: getBufioReader(rwc), - bw: getBufioWriter(rwc), - closer: rwc, - client: true, - } - c.extractBufioWriterBuf(rwc) - c.init() - - return c, resp, nil -} - -func verifyServerResponse(r *http.Request, resp *http.Response) error { - if resp.StatusCode != http.StatusSwitchingProtocols { - return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) - } - - if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { - return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) - } - - if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { - return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) - } - - if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", - resp.Header.Get("Sec-WebSocket-Accept"), - r.Header.Get("Sec-WebSocket-Key"), - ) - } - - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) - } - - return nil -} - -// The below pools can only be used by the client because http.Hijacker will always -// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top. - -var bufioReaderPool = sync.Pool{ - New: func() interface{} { - return bufio.NewReader(nil) - }, -} - -func getBufioReader(r io.Reader) *bufio.Reader { - br := bufioReaderPool.Get().(*bufio.Reader) - br.Reset(r) - return br -} - -func returnBufioReader(br *bufio.Reader) { - bufioReaderPool.Put(br) -} - -var bufioWriterPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriter(nil) - }, -} - -func getBufioWriter(w io.Writer) *bufio.Writer { - bw := bufioWriterPool.Get().(*bufio.Writer) - bw.Reset(w) - return bw -} - -func returnBufioWriter(bw *bufio.Writer) { - bufioWriterPool.Put(bw) -} - -func makeSecWebSocketKey() string { - b := make([]byte, 16) - rand.Read(b) - return base64.StdEncoding.EncodeToString(b) -} diff --git a/dial_test.go b/dial_test.go deleted file mode 100644 index 083b9bf3ed6bea5984a751f1e2a5fac7954d757a..0000000000000000000000000000000000000000 --- a/dial_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// +build !js - -package websocket - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func TestBadDials(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - url string - opts *DialOptions - }{ - { - name: "badURL", - url: "://noscheme", - }, - { - name: "badURLScheme", - url: "ftp://nhooyr.io", - }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, - { - name: "badTLS", - url: "wss://totallyfake.nhooyr.io", - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - _, _, err := Dial(ctx, tc.url, tc.opts) - if err == nil { - t.Fatalf("expected non nil error: %+v", err) - } - }) - } -} - -func Test_verifyServerHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - response func(w http.ResponseWriter) - success bool - }{ - { - name: "badStatus", - response: func(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) - }, - success: false, - }, - { - name: "badConnection", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badUpgrade", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badSecWebSocketAccept", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Accept", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badSecWebSocketProtocol", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Protocol", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "success", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: true, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - w := httptest.NewRecorder() - tc.response(w) - resp := w.Result() - - r := httptest.NewRequest("GET", "/", nil) - key := makeSecWebSocketKey() - r.Header.Set("Sec-WebSocket-Key", key) - - if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) - } - - err := verifyServerResponse(r, resp) - if (err == nil) != tc.success { - t.Fatalf("unexpected error: %+v", err) - } - }) - } -} diff --git a/doc.go b/doc.go index 4c07d37ab10173cc1866b3ab1d6af298b229bc7b..da6f32227207318d0ffa0aaa8cd3677d144a386f 100644 --- a/doc.go +++ b/doc.go @@ -22,13 +22,12 @@ // // The client side fully supports compiling to WASM. // It wraps the WebSocket browser API. +// // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // // Thus the unsupported features when compiling to WASM are: -// - Accept API -// - Reader/Writer API -// - SetReadLimit -// - Ping +// - Accept and AcceptOptions +// - Conn's Reader, Writer, SetReadLimit, Ping methods // - HTTPClient and HTTPHeader dial options // // The *http.Response returned by Dial will always either be nil or &http.Response{} as diff --git a/frame.go b/frame.go new file mode 100644 index 0000000000000000000000000000000000000000..10cb9e38e7de502ac6d6be7c99990ed3f1ecc4de --- /dev/null +++ b/frame.go @@ -0,0 +1,423 @@ +package websocket + +import ( + "encoding/binary" + "fmt" + "io" + "math" +) + +//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode,MessageType,StatusCode -output=frame_string.go + +// opcode represents a WebSocket Opcode. +type opcode int + +// opcode constants. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + +func (o opcode) controlOp() bool { + switch o { + case opClose, opPing, opPong: + return true + } + return false +} + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like Protobufs. + MessageBinary +) + +// First byte contains fin, rsv1, rsv2, rsv3. +// Second byte contains mask flag and payload len. +// Next 8 bytes are the maximum extended payload length. +// Last 4 bytes are the mask key. +// https://tools.ietf.org/html/rfc6455#section-5.2 +const maxHeaderSize = 1 + 1 + 8 + 4 + +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type header struct { + fin bool + rsv1 bool + rsv2 bool + rsv3 bool + opcode opcode + + payloadLength int64 + + masked bool + maskKey [4]byte +} + +func makeWriteHeaderBuf() []byte { + return make([]byte, maxHeaderSize) +} + +// bytes returns the bytes of the header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func writeHeader(b []byte, h header) []byte { + if b == nil { + b = makeWriteHeaderBuf() + } + + b = b[:2] + b[0] = 0 + + if h.fin { + b[0] |= 1 << 7 + } + if h.rsv1 { + b[0] |= 1 << 6 + } + if h.rsv2 { + b[0] |= 1 << 5 + } + if h.rsv3 { + b[0] |= 1 << 4 + } + + b[0] |= byte(h.opcode) + + switch { + case h.payloadLength < 0: + panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) + case h.payloadLength <= 125: + b[1] = byte(h.payloadLength) + case h.payloadLength <= math.MaxUint16: + b[1] = 126 + b = b[:len(b)+2] + binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) + default: + b[1] = 127 + b = b[:len(b)+8] + binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength)) + } + + if h.masked { + b[1] |= 1 << 7 + b = b[:len(b)+4] + copy(b[len(b)-4:], h.maskKey[:]) + } + + return b +} + +func makeReadHeaderBuf() []byte { + return make([]byte, maxHeaderSize-2) +} + +// readHeader reads a header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func readHeader(b []byte, r io.Reader) (header, error) { + if b == nil { + b = makeReadHeaderBuf() + } + + // We read the first two bytes first so that we know + // exactly how long the header is. + b = b[:2] + _, err := io.ReadFull(r, b) + if err != nil { + return header{}, err + } + + var h header + h.fin = b[0]&(1<<7) != 0 + h.rsv1 = b[0]&(1<<6) != 0 + h.rsv2 = b[0]&(1<<5) != 0 + h.rsv3 = b[0]&(1<<4) != 0 + + h.opcode = opcode(b[0] & 0xf) + + var extra int + + h.masked = b[1]&(1<<7) != 0 + if h.masked { + extra += 4 + } + + payloadLength := b[1] &^ (1 << 7) + switch { + case payloadLength < 126: + h.payloadLength = int64(payloadLength) + case payloadLength == 126: + extra += 2 + case payloadLength == 127: + extra += 8 + } + + if extra == 0 { + return h, nil + } + + b = b[:extra] + _, err = io.ReadFull(r, b) + if err != nil { + return header{}, err + } + + switch { + case payloadLength == 126: + h.payloadLength = int64(binary.BigEndian.Uint16(b)) + b = b[2:] + case payloadLength == 127: + h.payloadLength = int64(binary.BigEndian.Uint64(b)) + if h.payloadLength < 0 { + return header{}, fmt.Errorf("header with negative payload length: %v", h.payloadLength) + } + b = b[8:] + } + + if h.masked { + copy(h.maskKey[:], b) + } + + return h, nil +} + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// These codes were retrieved from: +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +const ( + StatusNormalClosure StatusCode = 1000 + iota + StatusGoingAway + StatusProtocolError + StatusUnsupportedData + + _ // 1004 is reserved. + + StatusNoStatusRcvd + + // This StatusCode is only exported for use with WASM. + // In non WASM Go, the returned error will indicate whether the connection was closed or not or what happened. + StatusAbnormalClosure + + StatusInvalidFramePayloadData + StatusPolicyViolation + StatusMessageTooBig + StatusMandatoryExtension + StatusInternalError + StatusServiceRestart + StatusTryAgainLater + StatusBadGateway + + // This StatusCode is only exported for use with WASM. + // In non WASM Go, the returned error will indicate whether there was a TLS handshake failure. + StatusTLSHandshake +) + +// CloseError represents a WebSocket close frame. +// It is returned by Conn's methods when a WebSocket close frame is received from +// the peer. +// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, +// to check for this error. See the CloseError example. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case 1004, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +const maxControlFramePayload = 125 + +func (ce CloseError) bytes() ([]byte, error) { + if len(ce.Reason) > maxControlFramePayload-2 { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) + } + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +// xor applies the WebSocket masking algorithm to p +// with the given key where the first 3 bits of pos +// are the starting position in the key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the position of the next byte +// to be used for masking in the key. This is so that +// unmasking can be performed without the entire frame. +func fastXOR(key [4]byte, keyPos int, b []byte) int { + // If the payload is greater than or equal to 16 bytes, then it's worth + // masking 8 bytes at a time. + // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 + if len(b) >= 16 { + // We first create a key that is 8 bytes long + // and is aligned on the position correctly. + var alignedKey [8]byte + for i := range alignedKey { + alignedKey[i] = key[(i+keyPos)&3] + } + k := binary.LittleEndian.Uint64(alignedKey[:]) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^k) + v = binary.LittleEndian.Uint64(b[8:]) + binary.LittleEndian.PutUint64(b[8:], v^k) + v = binary.LittleEndian.Uint64(b[16:]) + binary.LittleEndian.PutUint64(b[16:], v^k) + v = binary.LittleEndian.Uint64(b[24:]) + binary.LittleEndian.PutUint64(b[24:], v^k) + v = binary.LittleEndian.Uint64(b[32:]) + binary.LittleEndian.PutUint64(b[32:], v^k) + v = binary.LittleEndian.Uint64(b[40:]) + binary.LittleEndian.PutUint64(b[40:], v^k) + v = binary.LittleEndian.Uint64(b[48:]) + binary.LittleEndian.PutUint64(b[48:], v^k) + v = binary.LittleEndian.Uint64(b[56:]) + binary.LittleEndian.PutUint64(b[56:], v^k) + v = binary.LittleEndian.Uint64(b[64:]) + binary.LittleEndian.PutUint64(b[64:], v^k) + v = binary.LittleEndian.Uint64(b[72:]) + binary.LittleEndian.PutUint64(b[72:], v^k) + v = binary.LittleEndian.Uint64(b[80:]) + binary.LittleEndian.PutUint64(b[80:], v^k) + v = binary.LittleEndian.Uint64(b[88:]) + binary.LittleEndian.PutUint64(b[88:], v^k) + v = binary.LittleEndian.Uint64(b[96:]) + binary.LittleEndian.PutUint64(b[96:], v^k) + v = binary.LittleEndian.Uint64(b[104:]) + binary.LittleEndian.PutUint64(b[104:], v^k) + v = binary.LittleEndian.Uint64(b[112:]) + binary.LittleEndian.PutUint64(b[112:], v^k) + v = binary.LittleEndian.Uint64(b[120:]) + binary.LittleEndian.PutUint64(b[120:], v^k) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^k) + v = binary.LittleEndian.Uint64(b[8:]) + binary.LittleEndian.PutUint64(b[8:], v^k) + v = binary.LittleEndian.Uint64(b[16:]) + binary.LittleEndian.PutUint64(b[16:], v^k) + v = binary.LittleEndian.Uint64(b[24:]) + binary.LittleEndian.PutUint64(b[24:], v^k) + v = binary.LittleEndian.Uint64(b[32:]) + binary.LittleEndian.PutUint64(b[32:], v^k) + v = binary.LittleEndian.Uint64(b[40:]) + binary.LittleEndian.PutUint64(b[40:], v^k) + v = binary.LittleEndian.Uint64(b[48:]) + binary.LittleEndian.PutUint64(b[48:], v^k) + v = binary.LittleEndian.Uint64(b[56:]) + binary.LittleEndian.PutUint64(b[56:], v^k) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^k) + v = binary.LittleEndian.Uint64(b[8:]) + binary.LittleEndian.PutUint64(b[8:], v^k) + v = binary.LittleEndian.Uint64(b[16:]) + binary.LittleEndian.PutUint64(b[16:], v^k) + v = binary.LittleEndian.Uint64(b[24:]) + binary.LittleEndian.PutUint64(b[24:], v^k) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^k) + v = binary.LittleEndian.Uint64(b[8:]) + binary.LittleEndian.PutUint64(b[8:], v^k) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^k) + b = b[8:] + } + } + + // xor remaining bytes. + for i := range b { + b[i] ^= key[keyPos&3] + keyPos++ + } + return keyPos & 3 +} diff --git a/statuscode_string.go b/frame_string.go similarity index 51% rename from statuscode_string.go rename to frame_string.go index fc8cea0d6faa717bb71e2c3ce1f4426d635ffa76..6b32672a564ae51ac8d5e17e324fee314b16810e 100644 --- a/statuscode_string.go +++ b/frame_string.go @@ -1,9 +1,61 @@ -// Code generated by "stringer -type=StatusCode"; DO NOT EDIT. +// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=frame_string.go"; DO NOT EDIT. package websocket import "strconv" +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[opContinuation-0] + _ = x[opText-1] + _ = x[opBinary-2] + _ = x[opClose-8] + _ = x[opPing-9] + _ = x[opPong-10] +} + +const ( + _opcode_name_0 = "opContinuationopTextopBinary" + _opcode_name_1 = "opCloseopPingopPong" +) + +var ( + _opcode_index_0 = [...]uint8{0, 14, 20, 28} + _opcode_index_1 = [...]uint8{0, 7, 13, 19} +) + +func (i opcode) String() string { + switch { + case 0 <= i && i <= 2: + return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] + case 8 <= i && i <= 10: + i -= 8 + return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] + default: + return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[MessageText-1] + _ = x[MessageBinary-2] +} + +const _MessageType_name = "MessageTextMessageBinary" + +var _MessageType_index = [...]uint8{0, 11, 24} + +func (i MessageType) String() string { + i -= 1 + if i < 0 || i >= MessageType(len(_MessageType_index)-1) { + return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] +} func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. diff --git a/frame_test.go b/frame_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1a2054c12226a7729717e578d67a5848844ce0fe --- /dev/null +++ b/frame_test.go @@ -0,0 +1,373 @@ +// +build !js + +package websocket + +import ( + "bytes" + "io" + "math" + "math/rand" + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func randBool() bool { + return rand.Intn(1) == 0 +} + +func TestHeader(t *testing.T) { + t.Parallel() + + t.Run("eof", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + bytes []byte + }{ + { + "start", + []byte{0xff}, + }, + { + "middle", + []byte{0xff, 0xff, 0xff}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := bytes.NewBuffer(tc.bytes) + _, err := readHeader(nil, b) + if io.ErrUnexpectedEOF != err { + t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) + } + }) + } + }) + + t.Run("writeNegativeLength", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + t.Fatal("failed to induce panic in writeHeader with negative payload length") + } + }() + + writeHeader(nil, header{ + payloadLength: -1, + }) + }) + + t.Run("readNegativeLength", func(t *testing.T) { + t.Parallel() + + b := writeHeader(nil, header{ + payloadLength: 1<<16 + 1, + }) + + // Make length negative + b[2] |= 1 << 7 + + r := bytes.NewReader(b) + _, err := readHeader(nil, r) + if err == nil { + t.Fatalf("unexpected error value: %+v", err) + } + }) + + t.Run("lengths", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 124, + 125, + 126, + 4096, + 16384, + 65535, + 65536, + 65537, + 131072, + } + + for _, n := range lengths { + n := n + t.Run(strconv.Itoa(n), func(t *testing.T) { + t.Parallel() + + testHeader(t, header{ + payloadLength: int64(n), + }) + }) + } + }) + + t.Run("fuzz", func(t *testing.T) { + t.Parallel() + + for i := 0; i < 10000; i++ { + h := header{ + fin: randBool(), + rsv1: randBool(), + rsv2: randBool(), + rsv3: randBool(), + opcode: opcode(rand.Intn(1 << 4)), + + masked: randBool(), + payloadLength: rand.Int63(), + } + + if h.masked { + rand.Read(h.maskKey[:]) + } + + testHeader(t, h) + } + }) +} + +func testHeader(t *testing.T, h header) { + b := writeHeader(nil, h) + r := bytes.NewReader(b) + h2, err := readHeader(nil, r) + if err != nil { + t.Logf("header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("failed to read header: %v", err) + } + + if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { + t.Logf("header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) + } +} + +func TestCloseError(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + ce CloseError + success bool + }{ + { + name: "normal", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", maxControlFramePayload-2), + }, + success: true, + }, + { + name: "bigReason", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", maxControlFramePayload-1), + }, + success: false, + }, + { + name: "bigCode", + ce: CloseError{ + Code: math.MaxUint16, + Reason: strings.Repeat("x", maxControlFramePayload-2), + }, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := tc.ce.bytes() + if (err == nil) != tc.success { + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if (err == nil) != tc.success { + t.Fatalf("unexpected expected error value: %+v", err) + } + + if tc.success && tc.ce != ce { + t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if valid := validWireCloseCode(tc.code); tc.valid != valid { + t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) + } + }) + } +} + +func Test_xor(t *testing.T) { + t.Parallel() + + key := [4]byte{0xa, 0xb, 0xc, 0xff} + p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} + pos := 0 + pos = fastXOR(key, pos, p) + + if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { + t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) + } + + if exp := 1; !cmp.Equal(exp, pos) { + t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos)) + } +} + +func basixXOR(maskKey [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} + +func BenchmarkXOR(b *testing.B) { + sizes := []int{ + 2, + 16, + 32, + 512, + 4096, + 16384, + } + + fns := []struct { + name string + fn func([4]byte, int, []byte) int + }{ + { + "basic", + basixXOR, + }, + { + "fast", + fastXOR, + }, + } + + var maskKey [4]byte + _, err := rand.Read(maskKey[:]) + if err != nil { + b.Fatalf("failed to populate mask key: %v", err) + } + + for _, size := range sizes { + data := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + fn.fn(maskKey, 0, data) + } + }) + } + }) + } +} diff --git a/accept.go b/handshake.go similarity index 54% rename from accept.go rename to handshake.go index e68a049b32987c4e7eaebedc4fd4a4a48cad1a95..c55dd463a083e2d954e3a66c3ba076e650e7166a 100644 --- a/accept.go +++ b/handshake.go @@ -3,16 +3,21 @@ package websocket import ( + "bufio" "bytes" + "context" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" + "io/ioutil" + "math/rand" "net/http" "net/textproto" "net/url" "strings" + "sync" ) // AcceptOptions represents the options available to pass to Accept. @@ -221,3 +226,185 @@ func authenticateOrigin(r *http.Request) error { } return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } + +// DialOptions represents the options available to pass to Dial. +type DialOptions struct { + // HTTPClient is the http client used for the handshake. + // Its Transport must return writable bodies + // for WebSocket handshakes. + // http.Transport does this correctly beginning with Go 1.12. + HTTPClient *http.Client + + // HTTPHeader specifies the HTTP headers included in the handshake request. + HTTPHeader http.Header + + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string +} + +// Dial performs a WebSocket handshake on the given url with the given options. +// The response is the WebSocket handshake response from the server. +// If an error occurs, the returned response may be non nil. However, you can only +// read the first 1024 bytes of its body. +// +// You never need to close the resp.Body yourself. +// +// This function requires at least Go 1.12 to succeed as it uses a new feature +// in net/http to perform WebSocket handshakes and get a writable body +// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + c, r, err := dial(ctx, u, opts) + if err != nil { + return nil, r, fmt.Errorf("failed to websocket dial: %w", err) + } + return c, r, nil +} + +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + if opts == nil { + opts = &DialOptions{} + } + + // Shallow copy to ensure defaults do not affect user passed options. + opts2 := *opts + opts = &opts2 + + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } + if opts.HTTPClient.Timeout > 0 { + return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + + parsedURL, err := url.Parse(u) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse url: %w", err) + } + + switch parsedURL.Scheme { + case "ws": + parsedURL.Scheme = "http" + case "wss": + parsedURL.Scheme = "https" + default: + return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) + } + + req, _ := http.NewRequest("GET", parsedURL.String(), nil) + req = req.WithContext(ctx) + req.Header = opts.HTTPHeader + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey()) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) + } + defer func() { + if err != nil { + // We read a bit of the body for easier debugging. + r := io.LimitReader(resp.Body, 1024) + b, _ := ioutil.ReadAll(r) + resp.Body.Close() + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + }() + + err = verifyServerResponse(req, resp) + if err != nil { + return nil, resp, err + } + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) + } + + c := &Conn{ + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), + closer: rwc, + client: true, + } + c.extractBufioWriterBuf(rwc) + c.init() + + return c, resp, nil +} + +func verifyServerResponse(r *http.Request, resp *http.Response) error { + if resp.StatusCode != http.StatusSwitchingProtocols { + return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { + return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + } + + if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { + return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + } + + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { + return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + resp.Header.Get("Sec-WebSocket-Accept"), + r.Header.Get("Sec-WebSocket-Key"), + ) + } + + if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { + return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + } + + return nil +} + +// The below pools can only be used by the client because http.Hijacker will always +// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top. + +var bufioReaderPool = sync.Pool{ + New: func() interface{} { + return bufio.NewReader(nil) + }, +} + +func getBufioReader(r io.Reader) *bufio.Reader { + br := bufioReaderPool.Get().(*bufio.Reader) + br.Reset(r) + return br +} + +func returnBufioReader(br *bufio.Reader) { + bufioReaderPool.Put(br) +} + +var bufioWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriter(nil) + }, +} + +func getBufioWriter(w io.Writer) *bufio.Writer { + bw := bufioWriterPool.Get().(*bufio.Writer) + bw.Reset(w) + return bw +} + +func returnBufioWriter(bw *bufio.Writer) { + bufioWriterPool.Put(bw) +} + +func makeSecWebSocketKey() string { + b := make([]byte, 16) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} diff --git a/accept_test.go b/handshake_test.go similarity index 63% rename from accept_test.go rename to handshake_test.go index 44a956a85c4bf7f16a7c6abdfa029de04810626a..a3d98163f3d9bab3d88321a69f814d8f0768cb11 100644 --- a/accept_test.go +++ b/handshake_test.go @@ -3,9 +3,12 @@ package websocket import ( + "context" + "net/http" "net/http/httptest" "strings" "testing" + "time" ) func TestAccept(t *testing.T) { @@ -243,3 +246,138 @@ func Test_authenticateOrigin(t *testing.T) { }) } } + +func TestBadDials(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + }, + }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, tc.url, tc.opts) + if err == nil { + t.Fatalf("expected non nil error: %+v", err) + } + }) + } +} + +func Test_verifyServerHandshake(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + response func(w http.ResponseWriter) + success bool + }{ + { + name: "badStatus", + response: func(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + }, + success: false, + }, + { + name: "badConnection", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badUpgrade", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketAccept", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Accept", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketProtocol", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Protocol", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "success", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + tc.response(w) + resp := w.Result() + + r := httptest.NewRequest("GET", "/", nil) + key := makeSecWebSocketKey() + r.Header.Set("Sec-WebSocket-Key", key) + + if resp.Header.Get("Sec-WebSocket-Accept") == "" { + resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + } + + err := verifyServerResponse(r, resp) + if (err == nil) != tc.success { + t.Fatalf("unexpected error: %+v", err) + } + }) + } +} diff --git a/header.go b/header.go deleted file mode 100644 index 613b1d1510ffca71cdaf800a0d2cbbbec20c6833..0000000000000000000000000000000000000000 --- a/header.go +++ /dev/null @@ -1,158 +0,0 @@ -// +build !js - -package websocket - -import ( - "encoding/binary" - "fmt" - "io" - "math" -) - -// First byte contains fin, rsv1, rsv2, rsv3. -// Second byte contains mask flag and payload len. -// Next 8 bytes are the maximum extended payload length. -// Last 4 bytes are the mask key. -// https://tools.ietf.org/html/rfc6455#section-5.2 -const maxHeaderSize = 1 + 1 + 8 + 4 - -// header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type header struct { - fin bool - rsv1 bool - rsv2 bool - rsv3 bool - opcode opcode - - payloadLength int64 - - masked bool - maskKey [4]byte -} - -func makeWriteHeaderBuf() []byte { - return make([]byte, maxHeaderSize) -} - -// bytes returns the bytes of the header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func writeHeader(b []byte, h header) []byte { - if b == nil { - b = makeWriteHeaderBuf() - } - - b = b[:2] - b[0] = 0 - - if h.fin { - b[0] |= 1 << 7 - } - if h.rsv1 { - b[0] |= 1 << 6 - } - if h.rsv2 { - b[0] |= 1 << 5 - } - if h.rsv3 { - b[0] |= 1 << 4 - } - - b[0] |= byte(h.opcode) - - switch { - case h.payloadLength < 0: - panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) - case h.payloadLength <= 125: - b[1] = byte(h.payloadLength) - case h.payloadLength <= math.MaxUint16: - b[1] = 126 - b = b[:len(b)+2] - binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) - default: - b[1] = 127 - b = b[:len(b)+8] - binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength)) - } - - if h.masked { - b[1] |= 1 << 7 - b = b[:len(b)+4] - copy(b[len(b)-4:], h.maskKey[:]) - } - - return b -} - -func makeReadHeaderBuf() []byte { - return make([]byte, maxHeaderSize-2) -} - -// readHeader reads a header from the reader. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func readHeader(b []byte, r io.Reader) (header, error) { - if b == nil { - b = makeReadHeaderBuf() - } - - // We read the first two bytes first so that we know - // exactly how long the header is. - b = b[:2] - _, err := io.ReadFull(r, b) - if err != nil { - return header{}, err - } - - var h header - h.fin = b[0]&(1<<7) != 0 - h.rsv1 = b[0]&(1<<6) != 0 - h.rsv2 = b[0]&(1<<5) != 0 - h.rsv3 = b[0]&(1<<4) != 0 - - h.opcode = opcode(b[0] & 0xf) - - var extra int - - h.masked = b[1]&(1<<7) != 0 - if h.masked { - extra += 4 - } - - payloadLength := b[1] &^ (1 << 7) - switch { - case payloadLength < 126: - h.payloadLength = int64(payloadLength) - case payloadLength == 126: - extra += 2 - case payloadLength == 127: - extra += 8 - } - - if extra == 0 { - return h, nil - } - - b = b[:extra] - _, err = io.ReadFull(r, b) - if err != nil { - return header{}, err - } - - switch { - case payloadLength == 126: - h.payloadLength = int64(binary.BigEndian.Uint16(b)) - b = b[2:] - case payloadLength == 127: - h.payloadLength = int64(binary.BigEndian.Uint64(b)) - if h.payloadLength < 0 { - return header{}, fmt.Errorf("header with negative payload length: %v", h.payloadLength) - } - b = b[8:] - } - - if h.masked { - copy(h.maskKey[:], b) - } - - return h, nil -} diff --git a/header_test.go b/header_test.go deleted file mode 100644 index 5d0fd6a264fcdbacd9f6ef6cd2ce719929f9fb44..0000000000000000000000000000000000000000 --- a/header_test.go +++ /dev/null @@ -1,155 +0,0 @@ -// +build !js - -package websocket - -import ( - "bytes" - "io" - "math/rand" - "strconv" - "testing" - "time" - - "github.com/google/go-cmp/cmp" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBool() bool { - return rand.Intn(1) == 0 -} - -func TestHeader(t *testing.T) { - t.Parallel() - - t.Run("eof", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - bytes []byte - }{ - { - "start", - []byte{0xff}, - }, - { - "middle", - []byte{0xff, 0xff, 0xff}, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - b := bytes.NewBuffer(tc.bytes) - _, err := readHeader(nil, b) - if io.ErrUnexpectedEOF != err { - t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) - } - }) - } - }) - - t.Run("writeNegativeLength", func(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r == nil { - t.Fatal("failed to induce panic in writeHeader with negative payload length") - } - }() - - writeHeader(nil, header{ - payloadLength: -1, - }) - }) - - t.Run("readNegativeLength", func(t *testing.T) { - t.Parallel() - - b := writeHeader(nil, header{ - payloadLength: 1<<16 + 1, - }) - - // Make length negative - b[2] |= 1 << 7 - - r := bytes.NewReader(b) - _, err := readHeader(nil, r) - if err == nil { - t.Fatalf("unexpected error value: %+v", err) - } - }) - - t.Run("lengths", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 124, - 125, - 126, - 4096, - 16384, - 65535, - 65536, - 65537, - 131072, - } - - for _, n := range lengths { - n := n - t.Run(strconv.Itoa(n), func(t *testing.T) { - t.Parallel() - - testHeader(t, header{ - payloadLength: int64(n), - }) - }) - } - }) - - t.Run("fuzz", func(t *testing.T) { - t.Parallel() - - for i := 0; i < 10000; i++ { - h := header{ - fin: randBool(), - rsv1: randBool(), - rsv2: randBool(), - rsv3: randBool(), - opcode: opcode(rand.Intn(1 << 4)), - - masked: randBool(), - payloadLength: rand.Int63(), - } - - if h.masked { - rand.Read(h.maskKey[:]) - } - - testHeader(t, h) - } - }) -} - -func testHeader(t *testing.T, h header) { - b := writeHeader(nil, h) - r := bytes.NewReader(b) - h2, err := readHeader(nil, r) - if err != nil { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read header: %v", err) - } - - if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) - } -} diff --git a/messagetype.go b/messagetype.go deleted file mode 100644 index d6436b0b89bcae649f15c5c11bd77f48c7fdfa68..0000000000000000000000000000000000000000 --- a/messagetype.go +++ /dev/null @@ -1,17 +0,0 @@ -package websocket - -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -//go:generate go run golang.org/x/tools/cmd/stringer -type=MessageType - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. - MessageBinary -) - -// Above I've explicitly included the types of the constants for stringer. diff --git a/messagetype_string.go b/messagetype_string.go deleted file mode 100644 index bc62db93b22341aa36a0eb73b51cdb0bcf36678f..0000000000000000000000000000000000000000 --- a/messagetype_string.go +++ /dev/null @@ -1,25 +0,0 @@ -// Code generated by "stringer -type=MessageType"; DO NOT EDIT. - -package websocket - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[MessageText-1] - _ = x[MessageBinary-2] -} - -const _MessageType_name = "MessageTextMessageBinary" - -var _MessageType_index = [...]uint8{0, 11, 24} - -func (i MessageType) String() string { - i -= 1 - if i < 0 || i >= MessageType(len(_MessageType_index)-1) { - return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] -} diff --git a/opcode.go b/opcode.go deleted file mode 100644 index df708aa0baa8e8c4630bab0201bed823ab2dd470..0000000000000000000000000000000000000000 --- a/opcode.go +++ /dev/null @@ -1,31 +0,0 @@ -package websocket - -// opcode represents a WebSocket Opcode. -type opcode int - -//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode -tags js - -// opcode constants. -const ( - opContinuation opcode = iota - opText - opBinary - // 3 - 7 are reserved for further non-control frames. - _ - _ - _ - _ - _ - opClose - opPing - opPong - // 11-16 are reserved for further control frames. -) - -func (o opcode) controlOp() bool { - switch o { - case opClose, opPing, opPong: - return true - } - return false -} diff --git a/opcode_string.go b/opcode_string.go deleted file mode 100644 index d7b88961e4765a28c102310ad9ea564f9c042dcc..0000000000000000000000000000000000000000 --- a/opcode_string.go +++ /dev/null @@ -1,39 +0,0 @@ -// Code generated by "stringer -type=opcode -tags js"; DO NOT EDIT. - -package websocket - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[opContinuation-0] - _ = x[opText-1] - _ = x[opBinary-2] - _ = x[opClose-8] - _ = x[opPing-9] - _ = x[opPong-10] -} - -const ( - _opcode_name_0 = "opContinuationopTextopBinary" - _opcode_name_1 = "opCloseopPingopPong" -) - -var ( - _opcode_index_0 = [...]uint8{0, 14, 20, 28} - _opcode_index_1 = [...]uint8{0, 7, 13, 19} -) - -func (i opcode) String() string { - switch { - case 0 <= i && i <= 2: - return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] - case 8 <= i && i <= 10: - i -= 8 - return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] - default: - return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" - } -} diff --git a/statuscode.go b/statuscode.go deleted file mode 100644 index e7bb94999b45fc0279feddce500458376cee767b..0000000000000000000000000000000000000000 --- a/statuscode.go +++ /dev/null @@ -1,113 +0,0 @@ -package websocket - -import ( - "encoding/binary" - "fmt" -) - -// StatusCode represents a WebSocket status code. -// https://tools.ietf.org/html/rfc6455#section-7.4 -type StatusCode int - -//go:generate go run golang.org/x/tools/cmd/stringer -type=StatusCode - -// These codes were retrieved from: -// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -const ( - StatusNormalClosure StatusCode = 1000 + iota - StatusGoingAway - StatusProtocolError - StatusUnsupportedData - - _ // 1004 is reserved. - - StatusNoStatusRcvd - - // This StatusCode is only exported for use with WASM. - // In pure Go, the returned error will indicate whether the connection was closed or not or what happened. - StatusAbnormalClosure - - StatusInvalidFramePayloadData - StatusPolicyViolation - StatusMessageTooBig - StatusMandatoryExtension - StatusInternalError - StatusServiceRestart - StatusTryAgainLater - StatusBadGateway - - // This StatusCode is only exported for use with WASM. - // In pure Go, the returned error will indicate whether there was a TLS handshake failure. - StatusTLSHandshake -) - -// CloseError represents a WebSocket close frame. -// It is returned by Conn's methods when a WebSocket close frame is received from -// the peer. -// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, -// to check for this error. See the CloseError example. -type CloseError struct { - Code StatusCode - Reason string -} - -func (ce CloseError) Error() string { - return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case 1004, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -const maxControlFramePayload = 125 - -func (ce CloseError) bytes() ([]byte, error) { - if len(ce.Reason) > maxControlFramePayload-2 { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) - } - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} diff --git a/statuscode_test.go b/statuscode_test.go deleted file mode 100644 index b9637868361c4cc4400cf95af66fe4295b01f64a..0000000000000000000000000000000000000000 --- a/statuscode_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package websocket - -import ( - "math" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestCloseError(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - ce CloseError - success bool - }{ - { - name: "normal", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: true, - }, - { - name: "bigReason", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-1), - }, - success: false, - }, - { - name: "bigCode", - ce: CloseError{ - Code: math.MaxUint16, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - _, err := tc.ce.bytes() - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) - } - }) - } -} - -func Test_parseClosePayload(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - p []byte - success bool - ce CloseError - }{ - { - name: "normal", - p: append([]byte{0x3, 0xE8}, []byte("hello")...), - success: true, - ce: CloseError{ - Code: StatusNormalClosure, - Reason: "hello", - }, - }, - { - name: "nothing", - success: true, - ce: CloseError{ - Code: StatusNoStatusRcvd, - }, - }, - { - name: "oneByte", - p: []byte{0}, - success: false, - }, - { - name: "badStatusCode", - p: []byte{0x17, 0x70}, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ce, err := parseClosePayload(tc.p) - if (err == nil) != tc.success { - t.Fatalf("unexpected expected error value: %+v", err) - } - - if tc.success && tc.ce != ce { - t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) - } - }) - } -} - -func Test_validWireCloseCode(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - code StatusCode - valid bool - }{ - { - name: "normal", - code: StatusNormalClosure, - valid: true, - }, - { - name: "noStatus", - code: StatusNoStatusRcvd, - valid: false, - }, - { - name: "3000", - code: 3000, - valid: true, - }, - { - name: "4999", - code: 4999, - valid: true, - }, - { - name: "unknown", - code: 5000, - valid: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - if valid := validWireCloseCode(tc.code); tc.valid != valid { - t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) - } - }) - } -} diff --git a/websocket_autobahn_python_test.go b/websocket_autobahn_python_test.go deleted file mode 100644 index 62aa3f8e1ee26446735faf6ea45d8fa5642b55ea..0000000000000000000000000000000000000000 --- a/websocket_autobahn_python_test.go +++ /dev/null @@ -1,243 +0,0 @@ -// This file contains the old autobahn test suite tests that use the -// python binary. The approach is clunky and slow so new tests -// have been written in pure Go in websocket_test.go. -// These have been kept for correctness purposes and are occasionally ran. -// +build autobahn-python - -package websocket_test - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/httptest" - "os" - "os/exec" - "strconv" - "strings" - "testing" - "time" - - "nhooyr.io/websocket/internal/wsecho" -) - -// https://github.com/crossbario/autobahn-python/tree/master/wstest -func TestPythonAutobahnServer(t *testing.T) { - t.Parallel() - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := Accept(w, r, &AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - t.Logf("server handshake failed: %+v", err) - return - } - wsecho.Loop(r.Context(), c) - })) - defer s.Close() - - spec := map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - // 12.* and 13.* as we do not support compression. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} - wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() - if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) - } - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") -} - -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err - } - l.Close() - return l.Addr().String(), nil -} - -// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py -func TestPythonAutobahnClientOld(t *testing.T) { - t.Parallel() - - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) - } - - wsServerURL := "ws://" + serverAddr - - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), - // Disables some server that runs as part of fuzzingserver mode. - // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 - "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) - err = wstest.Start() - if err != nil { - t.Fatal(err) - } - defer func() { - err := wstest.Process.Kill() - if err != nil { - t.Error(err) - } - }() - - // Let it come up. - time.Sleep(time.Second * 5) - - var cases int - func() { - c, _, err := Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(StatusInternalError, "") - - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) - } - cases, err = strconv.Atoi(string(b)) - if err != nil { - t.Fatal(err) - } - - c.Close(StatusNormalClosure, "") - }() - - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) - defer cancel() - - c, _, err := Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } - wsecho.Loop(ctx, c) - }() - } - - c, _, err := Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) - if err != nil { - t.Fatal(err) - } - c.Close(StatusNormalClosure, "") - - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") -} - -func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } - - var indexJSON map[string]map[string]struct { - Behavior string `json:"behavior"` - BehaviorClose string `json:"behaviorClose"` - } - err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } - - var failed bool - for _, tests := range indexJSON { - for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) - } - } - } - - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) - } - } -} diff --git a/websocket_bench_test.go b/websocket_bench_test.go deleted file mode 100644 index ff2fd70416da5243cf90fb68678e914e2b645836..0000000000000000000000000000000000000000 --- a/websocket_bench_test.go +++ /dev/null @@ -1,148 +0,0 @@ -// +build !js - -package websocket_test - -import ( - "context" - "io" - "io/ioutil" - "net/http" - "strconv" - "strings" - "testing" - "time" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/wsecho" -) - -func BenchmarkConn(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } - - b.Run("write", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("stream", func(b *testing.B) { - benchConn(b, false, true, size) - }) - b.Run("buffer", func(b *testing.B) { - benchConn(b, false, false, size) - }) - }) - } - }) - - b.Run("echo", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) - }) - } - }) -} - -func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - if echo { - wsecho.Loop(r.Context(), c) - } else { - discardLoop(r.Context(), c) - } - return nil - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - - c, _, err := websocket.Dial(ctx, wsURL, nil) - if err != nil { - b.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - msg := []byte(strings.Repeat("2", size)) - readBuf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } - - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } - - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) - } - - _, err = io.ReadFull(r, readBuf) - if err != nil { - b.Fatal(err) - } - } - } - b.StopTimer() - - c.Close(websocket.StatusNormalClosure, "") -} - -func discardLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32768) - echo := func() error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - - _, err = io.CopyBuffer(ioutil.Discard, r, b) - if err != nil { - return err - } - return nil - } - - for { - err := echo() - if err != nil { - return - } - } -} diff --git a/websocket_js_test.go b/websocket_js_test.go index 1142190c2039f54d133abbfaa1b0d6f3b4ffaf38..e68ba6f3e0b1dd4faae9b1dfe1ab8cca154808f2 100644 --- a/websocket_js_test.go +++ b/websocket_js_test.go @@ -50,6 +50,4 @@ func TestConn(t *testing.T) { if err != nil { t.Fatal(err) } - - time.Sleep(time.Millisecond * 100) } diff --git a/xor.go b/xor.go deleted file mode 100644 index f9fe2051fceb16a186fe096b0071d1d7d1a87975..0000000000000000000000000000000000000000 --- a/xor.go +++ /dev/null @@ -1,127 +0,0 @@ -// +build !js - -package websocket - -import ( - "encoding/binary" -) - -// xor applies the WebSocket masking algorithm to p -// with the given key where the first 3 bits of pos -// are the starting position in the key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the position of the next byte -// to be used for masking in the key. This is so that -// unmasking can be performed without the entire frame. -func fastXOR(key [4]byte, keyPos int, b []byte) int { - // If the payload is greater than or equal to 16 bytes, then it's worth - // masking 8 bytes at a time. - // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 - if len(b) >= 16 { - // We first create a key that is 8 bytes long - // and is aligned on the position correctly. - var alignedKey [8]byte - for i := range alignedKey { - alignedKey[i] = key[(i+keyPos)&3] - } - k := binary.LittleEndian.Uint64(alignedKey[:]) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) - v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) - v = binary.LittleEndian.Uint64(b[16:]) - binary.LittleEndian.PutUint64(b[16:], v^k) - v = binary.LittleEndian.Uint64(b[24:]) - binary.LittleEndian.PutUint64(b[24:], v^k) - v = binary.LittleEndian.Uint64(b[32:]) - binary.LittleEndian.PutUint64(b[32:], v^k) - v = binary.LittleEndian.Uint64(b[40:]) - binary.LittleEndian.PutUint64(b[40:], v^k) - v = binary.LittleEndian.Uint64(b[48:]) - binary.LittleEndian.PutUint64(b[48:], v^k) - v = binary.LittleEndian.Uint64(b[56:]) - binary.LittleEndian.PutUint64(b[56:], v^k) - v = binary.LittleEndian.Uint64(b[64:]) - binary.LittleEndian.PutUint64(b[64:], v^k) - v = binary.LittleEndian.Uint64(b[72:]) - binary.LittleEndian.PutUint64(b[72:], v^k) - v = binary.LittleEndian.Uint64(b[80:]) - binary.LittleEndian.PutUint64(b[80:], v^k) - v = binary.LittleEndian.Uint64(b[88:]) - binary.LittleEndian.PutUint64(b[88:], v^k) - v = binary.LittleEndian.Uint64(b[96:]) - binary.LittleEndian.PutUint64(b[96:], v^k) - v = binary.LittleEndian.Uint64(b[104:]) - binary.LittleEndian.PutUint64(b[104:], v^k) - v = binary.LittleEndian.Uint64(b[112:]) - binary.LittleEndian.PutUint64(b[112:], v^k) - v = binary.LittleEndian.Uint64(b[120:]) - binary.LittleEndian.PutUint64(b[120:], v^k) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) - v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) - v = binary.LittleEndian.Uint64(b[16:]) - binary.LittleEndian.PutUint64(b[16:], v^k) - v = binary.LittleEndian.Uint64(b[24:]) - binary.LittleEndian.PutUint64(b[24:], v^k) - v = binary.LittleEndian.Uint64(b[32:]) - binary.LittleEndian.PutUint64(b[32:], v^k) - v = binary.LittleEndian.Uint64(b[40:]) - binary.LittleEndian.PutUint64(b[40:], v^k) - v = binary.LittleEndian.Uint64(b[48:]) - binary.LittleEndian.PutUint64(b[48:], v^k) - v = binary.LittleEndian.Uint64(b[56:]) - binary.LittleEndian.PutUint64(b[56:], v^k) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) - v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) - v = binary.LittleEndian.Uint64(b[16:]) - binary.LittleEndian.PutUint64(b[16:], v^k) - v = binary.LittleEndian.Uint64(b[24:]) - binary.LittleEndian.PutUint64(b[24:], v^k) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) - v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) - b = b[8:] - } - } - - // xor remaining bytes. - for i := range b { - b[i] ^= key[keyPos&3] - keyPos++ - } - return keyPos & 3 -} diff --git a/xor_test.go b/xor_test.go deleted file mode 100644 index 70047a9cba2440bf6be246a3a486eb6edbf1d795..0000000000000000000000000000000000000000 --- a/xor_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// +build !js - -package websocket - -import ( - "crypto/rand" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func Test_xor(t *testing.T) { - t.Parallel() - - key := [4]byte{0xa, 0xb, 0xc, 0xff} - p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - pos := 0 - pos = fastXOR(key, pos, p) - - if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { - t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) - } - - if exp := 1; !cmp.Equal(exp, pos) { - t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos)) - } -} - -func basixXOR(maskKey [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= maskKey[pos&3] - pos++ - } - return pos & 3 -} - -func BenchmarkXOR(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } - - fns := []struct { - name string - fn func([4]byte, int, []byte) int - }{ - { - "basic", - basixXOR, - }, - { - "fast", - fastXOR, - }, - } - - var maskKey [4]byte - _, err := rand.Read(maskKey[:]) - if err != nil { - b.Fatalf("failed to populate mask key: %v", err) - } - - for _, size := range sizes { - data := make([]byte, size) - - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { - b.ReportAllocs() - b.SetBytes(int64(size)) - - for i := 0; i < b.N; i++ { - fn.fn(maskKey, 0, data) - } - }) - } - }) - } -}