good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 94f9b715 authored by Anmol Sethi's avatar Anmol Sethi Committed by GitHub
Browse files

Merge pull request #187 from nhooyr/release-v1.8.0

Release v1.8.0
parents b9610079 4735f367
Branches
Tags v1.8.0
No related merge requests found
......@@ -2,12 +2,6 @@ all: fmt lint test
.SILENT:
.PHONY: *
.ONESHELL:
SHELL = bash
.SHELLFLAGS = -ceuo pipefail
include ci/fmt.mk
include ci/lint.mk
include ci/test.mk
# websocket
[![release](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases)
[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket)
[![coverage](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket)
[![ci](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions)
websocket is a minimal and idiomatic WebSocket library for Go.
......@@ -17,7 +14,8 @@ go get nhooyr.io/websocket
- Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support
- Thorough tests, fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- Thorough unit tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket)
- [Minimal dependencies](https://godoc.org/nhooyr.io/websocket?imports)
- JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages
- Zero alloc reads and writes
......@@ -111,8 +109,7 @@ Advantages of nhooyr.io/websocket:
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- Uses [klauspost/compress](https://github.com/klauspost/compress) for optimized compression
- See [gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)
- We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203))
- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
......
......@@ -6,14 +6,15 @@ import (
"bytes"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"golang.org/x/xerrors"
"nhooyr.io/websocket/internal/errd"
)
......@@ -85,7 +86,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
hj, ok := w.(http.Hijacker)
if !ok {
err = xerrors.New("http.ResponseWriter does not implement http.Hijacker")
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
return nil, err
}
......@@ -110,7 +111,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
netConn, brw, err := hj.Hijack()
if err != nil {
err = xerrors.Errorf("failed to hijack connection: %w", err)
err = fmt.Errorf("failed to hijack connection: %w", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return nil, err
}
......@@ -133,32 +134,32 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
if !r.ProtoAtLeast(1, 1) {
return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}
if !headerContainsToken(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
}
if !headerContainsToken(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
}
if r.Method != "GET" {
return http.StatusMethodNotAllowed, xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
}
if r.Header.Get("Sec-WebSocket-Version") != "13" {
w.Header().Set("Sec-WebSocket-Version", "13")
return http.StatusBadRequest, xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}
if r.Header.Get("Sec-WebSocket-Key") == "" {
return http.StatusBadRequest, xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}
return 0, nil
......@@ -169,10 +170,10 @@ func authenticateOrigin(r *http.Request) error {
if origin != "" {
u, err := url.Parse(origin)
if err != nil {
return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
}
if !strings.EqualFold(u.Host, r.Host) {
return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
}
}
return nil
......@@ -208,6 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
copts.serverMaxWindowBits = 8
for _, p := range ext.params {
switch p {
......@@ -219,11 +221,31 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
continue
}
if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") {
if strings.HasPrefix(p, "client_max_window_bits") {
continue
// bits, ok := parseExtensionParameter(p, 15)
// if !ok || bits < 8 || bits > 16 {
// err := fmt.Errorf("invalid client_max_window_bits: %q", p)
// http.Error(w, err.Error(), http.StatusBadRequest)
// return nil, err
// }
// copts.clientMaxWindowBits = bits
// continue
}
if false && strings.HasPrefix(p, "server_max_window_bits") {
// We always send back 8 but make sure to validate.
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
err := fmt.Errorf("invalid server_max_window_bits: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
continue
}
err := xerrors.Errorf("unsupported permessage-deflate parameter: %q", p)
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
......@@ -233,6 +255,21 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
return copts, nil
}
// parseExtensionParameter parses the value in the extension parameter p.
// It falls back to defaultVal if there is no value.
// If defaultVal == 0, then ok == false if there is no value.
func parseExtensionParameter(p string, defaultVal int) (int, bool) {
ps := strings.Split(p, "=")
if len(ps) == 1 {
if defaultVal > 0 {
return defaultVal, true
}
return 0, false
}
i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
return i, e == nil
}
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
// The peer must explicitly request it.
......@@ -253,7 +290,7 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com
//
// Either way, we're only implementing this for webkit which never sends the max_window_bits
// parameter so we don't need to worry about it.
err := xerrors.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
......
package websocket
import (
"errors"
"net/http"
"golang.org/x/xerrors"
)
// AcceptOptions represents Accept's options.
......@@ -16,5 +15,5 @@ type AcceptOptions struct {
// Accept is stubbed out for Wasm.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, xerrors.New("unimplemented")
return nil, errors.New("unimplemented")
}
......@@ -4,14 +4,13 @@ package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"golang.org/x/xerrors"
"nhooyr.io/websocket/internal/test/assert"
)
......@@ -80,7 +79,7 @@ func TestAccept(t *testing.T) {
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, xerrors.New("haha")
return nil, nil, errors.New("haha")
},
}
......@@ -328,6 +327,7 @@ func Test_acceptCompression(t *testing.T) {
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
serverMaxWindowBits: 8,
},
},
{
......
......@@ -15,8 +15,6 @@ import (
"testing"
"time"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert"
......@@ -108,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er
"exclude-cases": excludedAutobahnCases,
})
if err != nil {
return "", nil, xerrors.Errorf("failed to write spec: %w", err)
return "", nil, fmt.Errorf("failed to write spec: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
......@@ -126,7 +124,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er
wstest := exec.CommandContext(ctx, "wstest", args...)
err = wstest.Start()
if err != nil {
return "", nil, xerrors.Errorf("failed to start wstest: %w", err)
return "", nil, fmt.Errorf("failed to start wstest: %w", err)
}
return url, func() {
......@@ -209,7 +207,7 @@ func unusedListenAddr() (_ string, err error) {
func tempJSONFile(v interface{}) (string, error) {
f, err := ioutil.TempFile("", "temp.json")
if err != nil {
return "", xerrors.Errorf("temp file: %w", err)
return "", fmt.Errorf("temp file: %w", err)
}
defer f.Close()
......@@ -217,12 +215,12 @@ func tempJSONFile(v interface{}) (string, error) {
e.SetIndent("", "\t")
err = e.Encode(v)
if err != nil {
return "", xerrors.Errorf("json encode: %w", err)
return "", fmt.Errorf("json encode: %w", err)
}
err = f.Close()
if err != nil {
return "", xerrors.Errorf("close temp file: %w", err)
return "", fmt.Errorf("close temp file: %w", err)
}
return f.Name(), nil
......
#!/usr/bin/env bash
set -euo pipefail
main() {
local files
mapfile -t files < <(git ls-files --other --modified --exclude-standard)
if [[ ${files[*]} == "" ]]; then
return
fi
echo "Files need generation or are formatted incorrectly:"
for f in "${files[@]}"; do
echo " $f"
done
echo
echo "Please run the following locally:"
echo " make fmt"
exit 1
}
main "$@"
fmt: modtidy gofmt goimports prettier
fmt: modtidy gofmt goimports prettier shfmt
ifdef CI
if [[ $$(git ls-files --other --modified --exclude-standard) != "" ]]; then
echo "Files need generation or are formatted incorrectly:"
git -c color.ui=always status | grep --color=no '\e\[31m'
echo "Please run the following locally:"
echo " make fmt"
exit 1
fi
./ci/ensure_fmt.sh
endif
modtidy: gen
......@@ -23,3 +17,6 @@ prettier:
gen:
stringer -type=opcode,MessageType,StatusCode -output=stringer.go
shfmt:
shfmt -i 2 -w -s -sr $$(git ls-files "*.sh")
FROM golang:1
RUN apt-get update
RUN apt-get install -y chromium npm
RUN apt-get install -y chromium npm shellcheck
ARG SHFMT_URL=https://github.com/mvdan/sh/releases/download/v3.0.1/shfmt_v3.0.1_linux_amd64
RUN curl -L $SHFMT_URL > /usr/local/bin/shfmt && chmod +x /usr/local/bin/shfmt
ENV GOFLAGS="-mod=readonly"
ENV PAGER=cat
ENV CI=true
ENV MAKEFLAGS="--jobs=16 --output-sync=target"
......
lint: govet golint
lint: govet golint govet-wasm golint-wasm shellcheck
govet:
go vet ./...
......@@ -11,3 +11,6 @@ golint:
golint-wasm:
GOOS=js GOARCH=wasm golint -set_exit_status ./...
shellcheck:
shellcheck $$(git ls-files "*.sh")
......@@ -7,7 +7,6 @@ ci/out/coverage.html: gotest
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
coveralls: gotest
# https://github.com/coverallsapp/github-action/blob/master/src/run.ts
echo "--- coveralls"
goveralls -coverprofile=ci/out/coverage.prof
......
package websocket
import (
"errors"
"fmt"
"golang.org/x/xerrors"
)
// StatusCode represents a WebSocket status code.
......@@ -53,7 +52,7 @@ const (
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's xerrors.As to check for this error.
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
......@@ -64,13 +63,13 @@ func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's xerrors.As to grab
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if xerrors.As(err, &ce) {
if errors.As(err, &ce) {
return ce.Code
}
return -1
......
......@@ -5,11 +5,11 @@ package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"log"
"time"
"golang.org/x/xerrors"
"nhooyr.io/websocket/internal/errd"
)
......@@ -46,7 +46,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
return nil
}
var errAlreadyWroteClose = xerrors.New("already wrote close")
var errAlreadyWroteClose = errors.New("already wrote close")
func (c *Conn) writeClose(code StatusCode, reason string) error {
c.closeMu.Lock()
......@@ -62,7 +62,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error {
Reason: reason,
}
c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce))
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
var p []byte
var err error
......@@ -119,7 +119,7 @@ func parseClosePayload(p []byte) (CloseError, error) {
}
if len(p) < 2 {
return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
......@@ -128,7 +128,7 @@ func parseClosePayload(p []byte) (CloseError, error) {
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code)
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
......@@ -155,7 +155,7 @@ func validWireCloseCode(code StatusCode) bool {
func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = xerrors.Errorf("failed to marshal close frame: %w", err)
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
......@@ -168,11 +168,11 @@ const maxCloseReason = maxControlPayload - 2
func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, xerrors.Errorf("status code %v cannot be set", ce.Code)
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
......@@ -189,7 +189,7 @@ func (c *Conn) setCloseErr(err error) {
func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil {
c.closeErr = xerrors.Errorf("WebSocket closed: %w", err)
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
}
......
......@@ -3,6 +3,7 @@
package websocket
import (
"fmt"
"io"
"net/http"
"sync"
......@@ -19,7 +20,10 @@ func (m CompressionMode) opts() *compressionOptions {
type compressionOptions struct {
clientNoContextTakeover bool
clientMaxWindowBits int
serverNoContextTakeover bool
serverMaxWindowBits int
}
func (copts *compressionOptions) setHeader(h http.Header) {
......@@ -30,6 +34,12 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
if false && copts.serverMaxWindowBits > 0 {
s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits)
}
if false && copts.clientMaxWindowBits > 0 {
s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits)
}
h.Set("Sec-WebSocket-Extensions", s)
}
......@@ -152,9 +162,8 @@ func (sw *slidingWindow) close() {
}
swPoolMu.Lock()
defer swPoolMu.Unlock()
swPool[cap(sw.buf)].Put(sw.buf)
swPoolMu.Unlock()
sw.buf = nil
}
......
......@@ -5,13 +5,13 @@ package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"runtime"
"strconv"
"sync"
"sync/atomic"
"golang.org/x/xerrors"
)
// Conn represents a WebSocket connection.
......@@ -108,7 +108,7 @@ func newConn(cfg connConfig) *Conn {
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close(xerrors.New("connection garbage collected"))
c.close(errors.New("connection garbage collected"))
})
go c.timeoutLoop()
......@@ -165,10 +165,10 @@ func (c *Conn) timeoutLoop() {
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err()))
go c.writeError(StatusPolicyViolation, xerrors.New("timed out"))
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
case <-writeCtx.Done():
c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err()))
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
return
}
}
......@@ -190,7 +190,7 @@ func (c *Conn) Ping(ctx context.Context) error {
err := c.ping(ctx, strconv.Itoa(int(p)))
if err != nil {
return xerrors.Errorf("failed to ping: %w", err)
return fmt.Errorf("failed to ping: %w", err)
}
return nil
}
......@@ -217,7 +217,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
case <-c.closed:
return c.closeErr
case <-ctx.Done():
err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err())
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
case <-pong:
......@@ -242,7 +242,7 @@ func (m *mu) Lock(ctx context.Context) error {
case <-m.c.closed:
return m.c.closeErr
case <-ctx.Done():
err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err())
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err)
return err
case m.ch <- struct{}{}:
......
......@@ -19,7 +19,6 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/duration"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/test/assert"
......@@ -115,13 +114,21 @@ func TestConn(t *testing.T) {
for i := 0; i < count; i++ {
go func() {
errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg)
select {
case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
case <-tt.ctx.Done():
return
}
}()
}
for i := 0; i < count; i++ {
err := <-errs
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
}
err := c1.Close(websocket.StatusNormalClosure, "")
......@@ -172,8 +179,12 @@ func TestConn(t *testing.T) {
_, err = n1.Read(nil)
assert.Equal(t, "read error", err, io.EOF)
err = <-errs
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
assert.Equal(t, "read msg", []byte("hello"), b)
})
......@@ -196,8 +207,12 @@ func TestConn(t *testing.T) {
_, err := ioutil.ReadAll(n1)
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
err = <-errs
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
})
t.Run("wsjson", func(t *testing.T) {
......@@ -219,8 +234,12 @@ func TestConn(t *testing.T) {
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
err = <-werr
select {
case err := <-werr:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
......@@ -289,10 +308,10 @@ func TestWasm(t *testing.T) {
func assertCloseStatus(exp websocket.StatusCode, err error) error {
if websocket.CloseStatus(err) == -1 {
return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err)
return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
}
if websocket.CloseStatus(err) != exp {
return xerrors.Errorf("expected close status %v but got ", exp, err)
return fmt.Errorf("expected close status %v but got %v", exp, err)
}
return nil
}
......@@ -412,14 +431,22 @@ func BenchmarkConn(b *testing.B) {
go func() {
for range writes {
werrs <- c1.Write(bb.ctx, websocket.MessageText, msg)
select {
case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
case <-bb.ctx.Done():
return
}
}
}()
b.SetBytes(int64(len(msg)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
writes <- struct{}{}
select {
case writes <- struct{}{}:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
typ, r, err := c1.Reader(bb.ctx)
if err != nil {
......@@ -446,7 +473,11 @@ func BenchmarkConn(b *testing.B) {
assert.Equal(b, "msg", msg, readBuf)
}
err = <-werrs
select {
case err = <-werrs:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
if err != nil {
b.Fatal(err)
}
......
......@@ -8,14 +8,15 @@ import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"golang.org/x/xerrors"
"time"
"nhooyr.io/websocket/internal/errd"
)
......@@ -78,7 +79,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
secWebSocketKey, err := secWebSocketKey(rand)
if err != nil {
return nil, nil, xerrors.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}
resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey)
......@@ -91,6 +92,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(respBody, 1024)
timer := time.AfterFunc(time.Second*3, func() {
respBody.Close()
})
defer timer.Stop()
b, _ := ioutil.ReadAll(r)
respBody.Close()
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
......@@ -104,7 +111,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
rwc, ok := respBody.(io.ReadWriteCloser)
if !ok {
return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
}
return newConn(connConfig{
......@@ -120,12 +127,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) {
if opts.HTTPClient.Timeout > 0 {
return nil, xerrors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
}
u, err := url.Parse(urls)
if err != nil {
return nil, xerrors.Errorf("failed to parse url: %w", err)
return nil, fmt.Errorf("failed to parse url: %w", err)
}
switch u.Scheme {
......@@ -134,7 +141,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
case "wss":
u.Scheme = "https"
default:
return nil, xerrors.Errorf("unexpected url scheme: %q", u.Scheme)
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
......@@ -148,12 +155,13 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
}
if opts.CompressionMode != CompressionDisabled {
copts := opts.CompressionMode.opts()
copts.clientMaxWindowBits = 8
copts.setHeader(req.Header)
}
resp, err := opts.HTTPClient.Do(req)
if err != nil {
return nil, xerrors.Errorf("failed to send handshake request: %w", err)
return nil, fmt.Errorf("failed to send handshake request: %w", err)
}
return resp, nil
}
......@@ -165,26 +173,26 @@ func secWebSocketKey(rr io.Reader) (string, error) {
b := make([]byte, 16)
_, err := io.ReadFull(rr, b)
if err != nil {
return "", xerrors.Errorf("failed to read random data from rand.Reader: %w", err)
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
return nil, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}
if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
return nil, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, xerrors.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
......@@ -210,7 +218,7 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
}
}
return xerrors.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}
func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
......@@ -221,19 +229,40 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 {
return nil, xerrors.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}
copts := &compressionOptions{}
copts.clientMaxWindowBits = 8
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
default:
return nil, xerrors.Errorf("unsupported permessage-deflate parameter: %q", p)
continue
}
if false && strings.HasPrefix(p, "server_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid server_max_window_bits: %q", p)
}
copts.serverMaxWindowBits = bits
continue
}
if false && strings.HasPrefix(p, "client_max_window_bits") {
bits, ok := parseExtensionParameter(p, 0)
if !ok || bits < 8 || bits > 16 {
return nil, fmt.Errorf("invalid client_max_window_bits: %q", p)
}
copts.clientMaxWindowBits = 8
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
return copts, nil
......
......@@ -4,6 +4,7 @@ package websocket_test
import (
"context"
"errors"
"fmt"
"io"
"log"
......@@ -12,7 +13,6 @@ import (
"time"
"golang.org/x/time/rate"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
......@@ -78,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
if c.Subprotocol() != "echo" {
c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol")
return xerrors.New("client does not speak echo sub protocol")
return errors.New("client does not speak echo sub protocol")
}
l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
......@@ -88,7 +88,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
return nil
}
if err != nil {
return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err)
return fmt.Errorf("failed to echo with %v: %w", r.RemoteAddr, err)
}
}
}
......@@ -117,7 +117,7 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
_, err = io.Copy(w, r)
if err != nil {
return xerrors.Errorf("failed to io.Copy: %w", err)
return fmt.Errorf("failed to io.Copy: %w", err)
}
err = w.Close()
......
......@@ -3,12 +3,11 @@ package websocket
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"math/bits"
"golang.org/x/xerrors"
"nhooyr.io/websocket/internal/errd"
)
......@@ -87,7 +86,7 @@ func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
}
if h.payloadLength < 0 {
return header{}, xerrors.Errorf("received negative payload length: %v", h.payloadLength)
return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
}
if h.masked {
......
module nhooyr.io/websocket
go 1.12
go 1.13
require (
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect
......@@ -11,5 +11,4 @@ require (
github.com/gorilla/websocket v1.4.1
github.com/klauspost/compress v1.10.0
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment