diff --git a/ci/test.sh b/ci/test.sh index 3c476d93edb11cda154f1e0a661c06cfd27b0cf8..c8b8ec19922b13ced0585b6d63ac2b80240d0b99 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -12,14 +12,13 @@ argv=( -- "-vet=off" ) -# Interactive usage does not want to turn off vet or use gotestsum by default. +# Interactive usage does not want to turn off vet or use gotestsum. if [[ $# -gt 0 ]]; then argv=(go test "$@") fi # We always want coverage and race detection. argv+=( - -race "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) diff --git a/export_test.go b/export_test.go index fc885bffd092430565252225a4e07423333025b6..811bf800b94011740d6a9f491f9ff8f9bfd8be8b 100644 --- a/export_test.go +++ b/export_test.go @@ -1,27 +1,124 @@ package websocket import ( + "bufio" "context" + + "golang.org/x/xerrors" +) + +type ( + Addr = websocketAddr + OpCode int ) -type Addr = websocketAddr +const ( + OpClose = OpCode(opClose) + OpBinary = OpCode(opBinary) + OpText = OpCode(opText) + OpPing = OpCode(opPing) + OpPong = OpCode(opPong) + OpContinuation = OpCode(opContinuation) +) + +func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) { + h, err := c.readFrameHeader(ctx) + if err != nil { + return 0, nil, err + } + b := make([]byte, h.payloadLength) + _, err = c.readFramePayload(ctx, b) + if err != nil { + return 0, nil, err + } + if h.masked { + fastXOR(h.maskKey, 0, b) + } + return OpCode(h.opcode), b, nil +} + +func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) (int, error) { + return c.writeFrame(ctx, fin, opcode(opc), p) +} + +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Header struct { + Fin bool + Rsv1 bool + Rsv2 bool + Rsv3 bool + OpCode OpCode + + PayloadLength int64 +} -const OPClose = opClose -const OPBinary = opBinary -const OPPing = opPing -const OPContinuation = opContinuation +func (c *Conn) WriteHeader(ctx context.Context, h Header) error { + headerBytes := writeHeader(c.writeHeaderBuf, header{ + fin: h.Fin, + rsv1: h.Rsv1, + rsv2: h.Rsv2, + rsv3: h.Rsv3, + opcode: opcode(h.OpCode), + payloadLength: h.PayloadLength, + masked: c.client, + }) + _, err := c.bw.Write(headerBytes) + if err != nil { + return xerrors.Errorf("failed to write header: %w", err) + } + if h.Fin { + err = c.Flush() + if err != nil { + return err + } + } + return nil +} -func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - return c.writeFrame(ctx, fin, opcode, p) +func (c *Conn) PingWithPayload(ctx context.Context, p string) error { + return c.ping(ctx, p) } func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { return c.realWriteFrame(ctx, header{ + fin: true, opcode: opBinary, - payloadLength: 5, - }, make([]byte, 10)) + payloadLength: 10, + }, make([]byte, 5)) +} + +func (c *Conn) CloseUnderlyingConn() { + c.closer.Close() } func (c *Conn) Flush() error { return c.bw.Flush() } + +func (c CloseError) Bytes() ([]byte, error) { + return c.bytes() +} + +func (c *Conn) BW() *bufio.Writer { + return c.bw +} + +func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) ([]byte, error) { + b, err := CloseError{ + Code: code, + Reason: reason, + }.Bytes() + if err != nil { + return nil, err + } + _, err = c.WriteFrame(ctx, true, OpClose, b) + if err != nil { + return nil, err + } + return b, nil +} + +func ParseClosePayload(p []byte) (CloseError, error) { + return parseClosePayload(p) +} diff --git a/go.mod b/go.mod index c9cc6fc426bfcce46a2b2bf367fc37009518f22d..70fe1d4c111e94136ef01f194866cf2e86228d77 100644 --- a/go.mod +++ b/go.mod @@ -3,18 +3,29 @@ module nhooyr.io/websocket go 1.12 require ( - github.com/golang/protobuf v1.3.1 - github.com/google/go-cmp v0.2.0 + github.com/fatih/color v1.7.0 // indirect + github.com/golang/protobuf v1.3.2 + github.com/google/go-cmp v0.3.1 + github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/kr/pretty v0.1.0 // indirect + github.com/mattn/go-colorable v0.1.2 // indirect + github.com/mattn/go-isatty v0.0.9 // indirect + github.com/pkg/errors v0.8.1 // indirect + github.com/sirupsen/logrus v1.4.2 // indirect + github.com/spf13/pflag v1.0.3 // indirect + github.com/stretchr/testify v1.4.0 // indirect go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16 go.uber.org/atomic v1.4.0 // indirect go.uber.org/multierr v1.1.0 golang.org/x/lint v0.0.0-20190409202823-959b441ac422 - golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 + golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 + golang.org/x/sys v0.0.0-20190830142957-1e83adbbebd0 // indirect golang.org/x/text v0.3.2 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - golang.org/x/tools v0.0.0-20190429184909-35c670923e21 - golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 + golang.org/x/tools v0.0.0-20190830223141-573d9926052a + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gotest.tools v2.2.0+incompatible // indirect gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52 mvdan.cc/sh v2.6.4+incompatible ) diff --git a/go.sum b/go.sum index 187a2285de7d3cd67e20fc74c5b0377b5b0e560b..906f9c38cb955e94ef5e56445b685b9959199d7f 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,27 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.6.0 h1:66qjqZk8kalYAvDRtM1AdAJQI0tj4Wrue3Eq3B3pmFU= github.com/fatih/color v1.6.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -20,8 +29,13 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-colorable v0.0.9 h1:UVL0vNpWh04HeJXV0KLcaT7r06gOH2l4OW6ddYRUIY4= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.3 h1:ns/ykhmWi7G9O+8a448SecJU3nSMBXJfqQkl0upE1jI= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0 h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -29,15 +43,25 @@ github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.0.5 h1:8c8b5uO0zS4X6RPl/sd1ENwSkIc0/H2PaHxE3udaE8I= github.com/sirupsen/logrus v1.0.5/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4= github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16 h1:3gGa1bM0nG7Ruhu5b7wKnoOOwAD/fJ8iyyAcpOzDG3A= go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16/go.mod h1:iKV5yK9t+J5nG9O3uF6KYdPEz3dyfMyB15MN1rbQ8Qw= go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= @@ -53,14 +77,20 @@ golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20181102091132-c10e9556a7bc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 h1:FP8hkuE6yUEaJnK7O2eTuejKWwW+Rhfj80dQ2JcKxCU= -golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190830142957-1e83adbbebd0 h1:7z820YPX9pxWR59qM7BE5+fglp4D/mKqAwCvGt11b+8= +golang.org/x/sys v0.0.0-20190830142957-1e83adbbebd0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -69,14 +99,16 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZe golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190429184909-35c670923e21 h1:Kjcw+D2LTzLmxOHrMK9uvYP/NigJ0EdwMgzt6EU+Ghs= -golang.org/x/tools v0.0.0-20190429184909-35c670923e21/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/tools v0.0.0-20190830223141-573d9926052a h1:XAHT1kdPpnU8Hk+FPi42KZFhtNFEk4vBg1U4OmIeHTU= +golang.org/x/tools v0.0.0-20190830223141-573d9926052a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/airbrake/gobrake.v2 v2.0.9 h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2 h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0= @@ -85,8 +117,12 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.1.0+incompatible h1:5USw7CrJBYKqjg9R7QlA6jzqZKEAtvW82aNmsxxGPxw= gotest.tools v2.1.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52 h1:Qr31uPFyjpOhAgRfKV4ATUnknnLT2X7HFjqwkstdbbE= gotest.tools/gotestsum v0.3.6-0.20190825182939-fc6cb5870c52/go.mod h1:Mnf3e5FUzXbkCfynWBGOwLssY7gTQgCHObK9tMpAriY= mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM= diff --git a/websocket.go b/websocket.go index 6f28a4bff89a6711a435d42f23cd38847531ab98..7dabfa25f686df3de346a2dcddb86dc527bdd315 100644 --- a/websocket.go +++ b/websocket.go @@ -886,17 +886,17 @@ func init() { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - err := c.ping(ctx) + id := rand.Uint64() + p := strconv.FormatUint(id, 10) + + err := c.ping(ctx, p) if err != nil { return xerrors.Errorf("failed to ping: %w", err) } return nil } -func (c *Conn) ping(ctx context.Context) error { - id := rand.Uint64() - p := strconv.FormatUint(id, 10) - +func (c *Conn) ping(ctx context.Context, p string) error { pong := make(chan struct{}) c.activePingsMu.Lock() diff --git a/websocket_autobahn_python_test.go b/websocket_autobahn_python_test.go new file mode 100644 index 0000000000000000000000000000000000000000..32ee1f5ceeeaf89377ddffdfc1b3a34b234bc493 --- /dev/null +++ b/websocket_autobahn_python_test.go @@ -0,0 +1,239 @@ +// +build autobahn-python + +package websocket_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "strconv" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +// https://github.com/crossbario/autobahn-python/tree/master/wstest +func TestPythonAutobahnServer(t *testing.T) { + t.Parallel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) + if err != nil { + t.Logf("server handshake failed: %+v", err) + return + } + echoLoop(r.Context(), c) + })) + defer s.Close() + + spec := map[string]interface{}{ + "outdir": "ci/out/wstestServerReports", + "servers": []interface{}{ + map[string]interface{}{ + "agent": "main", + "url": strings.Replace(s.URL, "http", "ws", 1), + }, + }, + "cases": []string{"*"}, + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. 7.5.1 is the same. + // 12.* and 13.* as we do not support compression. + "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + wstest := exec.CommandContext(ctx, "wstest", args...) + out, err := wstest.CombinedOutput() + if err != nil { + t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + } + + checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") +} + +func unusedListenAddr() (string, error) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py +func TestPythonAutobahnClientOld(t *testing.T) { + t.Parallel() + + serverAddr, err := unusedListenAddr() + if err != nil { + t.Fatalf("failed to get unused listen addr for wstest: %v", err) + } + + wsServerURL := "ws://" + serverAddr + + spec := map[string]interface{}{ + "url": wsServerURL, + "outdir": "ci/out/wstestClientReports", + "cases": []string{"*"}, + // See TestAutobahnServer for the reasons why we exclude these. + "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), + // Disables some server that runs as part of fuzzingserver mode. + // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 + "--webport=0", + } + wstest := exec.CommandContext(ctx, "wstest", args...) + err = wstest.Start() + if err != nil { + t.Fatal(err) + } + defer func() { + err := wstest.Process.Kill() + if err != nil { + t.Error(err) + } + }() + + // Let it come up. + time.Sleep(time.Second * 5) + + var cases int + func() { + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + _, r, err := c.Reader(ctx) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + cases, err = strconv.Atoi(string(b)) + if err != nil { + t.Fatal(err) + } + + c.Close(websocket.StatusNormalClosure, "") + }() + + for i := 1; i <= cases; i++ { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*45) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) + if err != nil { + t.Fatal(err) + } + echoLoop(ctx, c) + }() + } + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) + if err != nil { + t.Fatal(err) + } + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func checkWSTestIndex(t *testing.T, path string) { + wstestOut, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("failed to read index.json: %v", err) + } + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + BehaviorClose string `json:"behaviorClose"` + } + err = json.Unmarshal(wstestOut, &indexJSON) + if err != nil { + t.Fatalf("failed to unmarshal index.json: %v", err) + } + + var failed bool + for _, tests := range indexJSON { + for test, result := range tests { + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + failed = true + t.Errorf("test %v failed", test) + } + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + failed = true + t.Errorf("bad close behaviour for test %v", test) + } + } + } + + if failed { + path = strings.Replace(path, ".json", ".html", 1) + if os.Getenv("CI") == "" { + t.Errorf("wstest found failure, please see %q (output as an artifact in CI)", path) + } + } +} diff --git a/websocket_bench_test.go b/websocket_bench_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6a54fab21c0dbf3106db21a48e13a3d51d5ba61a --- /dev/null +++ b/websocket_bench_test.go @@ -0,0 +1,145 @@ +package websocket_test + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func BenchmarkConn(b *testing.B) { + sizes := []int{ + 2, + 16, + 32, + 512, + 4096, + 16384, + } + + b.Run("write", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + b.Run("stream", func(b *testing.B) { + benchConn(b, false, true, size) + }) + b.Run("buffer", func(b *testing.B) { + benchConn(b, false, false, size) + }) + }) + } + }) + + b.Run("echo", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + benchConn(b, true, true, size) + }) + } + }) +} + +func benchConn(b *testing.B, echo, stream bool, size int) { + s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, nil) + if err != nil { + return err + } + if echo { + echoLoop(r.Context(), c) + } else { + discardLoop(r.Context(), c) + } + return nil + }, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + b.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + msg := []byte(strings.Repeat("2", size)) + readBuf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if stream { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } + + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) + } + + err = w.Close() + if err != nil { + b.Fatal(err) + } + } else { + err = c.Write(ctx, websocket.MessageText, msg) + if err != nil { + b.Fatal(err) + } + } + + if echo { + _, r, err := c.Reader(ctx) + if err != nil { + b.Fatal(err) + } + + _, err = io.ReadFull(r, readBuf) + if err != nil { + b.Fatal(err) + } + } + } + b.StopTimer() + + c.Close(websocket.StatusNormalClosure, "") +} + +func discardLoop(ctx context.Context, c *websocket.Conn) { + defer c.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32768) + echo := func() error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + + _, err = io.CopyBuffer(ioutil.Discard, r, b) + if err != nil { + return err + } + return nil + } + + for { + err := echo() + if err != nil { + return + } + } +} diff --git a/websocket_test.go b/websocket_test.go index e6529f3bb94e860efa14ecebddc65a8240dbf1e5..3482cbde8348d993f5850773b1a1e8f9aa201c29 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -1,19 +1,18 @@ package websocket_test import ( + "bytes" "context" + "encoding/binary" "encoding/json" "fmt" "io" "io/ioutil" "math/rand" - "net" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" - "os" - "os/exec" "reflect" "strconv" "strings" @@ -175,7 +174,7 @@ func TestHandshake(t *testing.T) { wsURL := strings.Replace(s.URL, "http", "ws", 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := tc.client(ctx, wsURL) @@ -598,7 +597,7 @@ func TestConn(t *testing.T) { { name: "largeControlFrame", server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OPClose, []byte(strings.Repeat("x", 4096))) + _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte(strings.Repeat("x", 4096))) if err != nil { return err } @@ -613,7 +612,7 @@ func TestConn(t *testing.T) { { name: "fragmentedControlFrame", server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OPPing, []byte(strings.Repeat("x", 32))) + _, err := c.WriteFrame(ctx, false, websocket.OpPing, []byte(strings.Repeat("x", 32))) if err != nil { return err } @@ -632,7 +631,7 @@ func TestConn(t *testing.T) { { name: "invalidClosePayload", server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OPClose, []byte{0x17, 0x70}) + _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{0x17, 0x70}) if err != nil { return err } @@ -736,7 +735,7 @@ func TestConn(t *testing.T) { if err != nil { return xerrors.Errorf("failed to flush: %w", err) } - _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) if err != nil { return xerrors.Errorf("expected non nil error") } @@ -751,7 +750,7 @@ func TestConn(t *testing.T) { return assertErrorContains(err, "received continuation frame not after data") }, client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OPContinuation, []byte(strings.Repeat("x", 10))) + _, err := c.WriteFrame(ctx, false, websocket.OpContinuation, []byte(strings.Repeat("x", 10))) return err }, }, @@ -810,7 +809,7 @@ func TestConn(t *testing.T) { if err != nil { return xerrors.Errorf("failed to flush: %w", err) } - _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) if err != nil { return xerrors.Errorf("expected non nil error") } @@ -844,7 +843,11 @@ func TestConn(t *testing.T) { }, client: func(ctx context.Context, c *websocket.Conn) error { _, err := c.WriteHalfFrame(ctx) - return err + if err != nil { + return err + } + c.CloseUnderlyingConn() + return nil }, }, } @@ -871,7 +874,7 @@ func TestConn(t *testing.T) { wsURL := strings.Replace(s.URL, "http", "ws", 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() opts := tc.dialOpts @@ -917,7 +920,7 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e atomic.AddInt64(&conns, 1) defer atomic.AddInt64(&conns, -1) - ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) defer cancel() r = r.WithContext(ctx) @@ -946,406 +949,956 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e } } -// https://github.com/crossbario/autobahn-python/tree/master/wstest -func TestAutobahnServer(t *testing.T) { - t.Parallel() - if os.Getenv("AUTOBAHN") == "" { - t.Skip("Set $AUTOBAHN to run the autobahn test suite.") - } - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - t.Logf("server handshake failed: %+v", err) - return - } - echoLoop(r.Context(), c) - })) - defer s.Close() - - spec := map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - // 12.* and 13.* as we do not support compression. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} - wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() - if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) - } - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") -} - -func echoLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 40) - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32768) - echo := func() error { - typ, r, err := c.Reader(ctx) - if err != nil { - return err - } - - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - - _, err = io.CopyBuffer(w, r, b) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - - return nil - } - - for { - err := echo() - if err != nil { - return - } - } -} - -func discardLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32768) - echo := func() error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - - _, err = io.CopyBuffer(ioutil.Discard, r, b) - if err != nil { - return err - } - return nil - } - - for { - err := echo() - if err != nil { - return - } - } -} - -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err - } - l.Close() - return l.Addr().String(), nil -} - -// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py -func TestAutobahnClient(t *testing.T) { +func TestAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN") == "" { - t.Skip("Set $AUTOBAHN to run the autobahn test suite.") - } - - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) - } - - wsServerURL := "ws://" + serverAddr - - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), - // Disables some server that runs as part of fuzzingserver mode. - // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 - "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) - err = wstest.Start() - if err != nil { - t.Fatal(err) - } - defer func() { - err := wstest.Process.Kill() - if err != nil { - t.Error(err) - } - }() + run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) { + run2 := func(t *testing.T, testingClient bool) { + // Run random tests over TLS. + tls := rand.Intn(2) == 1 - // Let it come up. - time.Sleep(time.Second * 5) + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + c.SetReadLimit(1 << 40) - var cases int - func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") + ctx := r.Context() + if testingClient { + echoLoop(r.Context(), c) + return nil + } - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) - } - cases, err = strconv.Atoi(string(b)) - if err != nil { - t.Fatal(err) - } + err = fn(ctx, c) + if err != nil { + return err + } + c.Close(websocket.StatusNormalClosure, "") + return nil + }, tls) + defer closeFn() - c.Close(websocket.StatusNormalClosure, "") - }() + wsURL := strings.Replace(s.URL, "http", "ws", 1) - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) + opts := &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + } + if tls { + opts.HTTPClient = s.Client() + } + + c, _, err := websocket.Dial(ctx, wsURL, opts) if err != nil { t.Fatal(err) } - echoLoop(ctx, c) - }() - } - - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) - if err != nil { - t.Fatal(err) - } - c.Close(websocket.StatusNormalClosure, "") - - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") -} - -func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } - - var indexJSON map[string]map[string]struct { - Behavior string `json:"behavior"` - BehaviorClose string `json:"behaviorClose"` - } - err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } + defer c.Close(websocket.StatusInternalError, "") + c.SetReadLimit(1 << 40) - var failed bool - for _, tests := range indexJSON { - for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) + if testingClient { + err = fn(ctx, c) + if err != nil { + t.Fatalf("client failed: %+v", err) + } + c.Close(websocket.StatusNormalClosure, "") + return } - } - } - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, please see %q (output as an artifact in CI)", path) + echoLoop(ctx, c) } + t.Run(name, func(t *testing.T) { + t.Parallel() + + t.Run("server", func(t *testing.T) { + t.Parallel() + run2(t, false) + }) + t.Run("client", func(t *testing.T) { + t.Parallel() + run2(t, true) + }) + }) } -} -func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err + // Section 1. + t.Run("echo", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 0, + 125, + 126, + 127, + 128, + 65535, + 65536, + 65536, } - if echo { - echoLoop(r.Context(), c) - } else { - discardLoop(r.Context(), c) + run := func(typ websocket.MessageType) { + for i, l := range lengths { + l := l + run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error { + p := randBytes(l) + if i == len(lengths)-1 { + w, err := c.Writer(ctx, typ) + if err != nil { + return err + } + for i := 0; i < l; { + j := i + 997 + if j > l { + j = l + } + _, err = w.Write(p[i:j]) + if err != nil { + return err + } + + i = j + } + + err = w.Close() + if err != nil { + return err + } + } else { + err := c.Write(ctx, typ, p) + if err != nil { + return err + } + } + actTyp, p2, err := c.Read(ctx) + if err != nil { + return err + } + err = assertEqualf(typ, actTyp, "unexpected message type") + if err != nil { + return err + } + return assertEqualf(p, p2, "unexpected message") + }) + } } - return nil - }, false) - defer closeFn() - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - - c, _, err := websocket.Dial(ctx, wsURL, nil) - if err != nil { - b.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") + run(websocket.MessageText) + run(websocket.MessageBinary) + }) - msg := []byte(strings.Repeat("2", size)) - readBuf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } + // Section 2. + t.Run("pingPong", func(t *testing.T) { + t.Parallel() - _, err = w.Write(msg) + run(t, "emptyPayload", func(ctx context.Context, c *websocket.Conn) error { + ctx = c.CloseRead(ctx) + return c.PingWithPayload(ctx, "") + }) + run(t, "smallTextPayload", func(ctx context.Context, c *websocket.Conn) error { + ctx = c.CloseRead(ctx) + return c.PingWithPayload(ctx, "hi") + }) + run(t, "smallBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { + ctx = c.CloseRead(ctx) + p := bytes.Repeat([]byte{0xFE}, 16) + return c.PingWithPayload(ctx, string(p)) + }) + run(t, "largeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { + ctx = c.CloseRead(ctx) + p := bytes.Repeat([]byte{0xFE}, 125) + return c.PingWithPayload(ctx, string(p)) + }) + run(t, "tooLargeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { + c.CloseRead(ctx) + p := bytes.Repeat([]byte{0xFE}, 126) + err := c.PingWithPayload(ctx, string(p)) + return assertCloseStatus(err, websocket.StatusProtocolError) + }) + run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error { + err := assertStreamPing(ctx, c, 125) if err != nil { - b.Fatal(err) + return err } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + t.Run("unsolicitedPong", func(t *testing.T) { + t.Parallel() - err = w.Close() - if err != nil { - b.Fatal(err) + var testCases = []struct { + name string + pongPayload string + ping bool + }{ + { + name: "noPayload", + pongPayload: "", + }, + { + name: "payload", + pongPayload: "hi", + }, + { + name: "pongThenPing", + pongPayload: "hi", + ping: true, + }, } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpPong, []byte(tc.pongPayload)) + if err != nil { + return err + } + if tc.ping { + _, err := c.WriteFrame(ctx, true, websocket.OpPing, []byte("meow")) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpPong, []byte("meow")) + if err != nil { + return err + } + } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) } - } + }) + run(t, "tenPings", func(ctx context.Context, c *websocket.Conn) error { + ctx = c.CloseRead(ctx) - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) + for i := 0; i < 10; i++ { + err := c.Ping(ctx) + if err != nil { + return err + } } - _, err = io.ReadFull(r, readBuf) + _, err := c.WriteClose(ctx, websocket.StatusNormalClosure, "") if err != nil { - b.Fatal(err) + return err } - } - } - b.StopTimer() - - c.Close(websocket.StatusNormalClosure, "") -} + <-ctx.Done() -func BenchmarkConn(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } + err = c.Ping(context.Background()) + return assertCloseStatus(err, websocket.StatusNormalClosure) + }) + run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error { + for i := 0; i < 10; i++ { + err := assertStreamPing(ctx, c, 125) + if err != nil { + return err + } + } - b.Run("write", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("stream", func(b *testing.B) { - benchConn(b, false, true, size) - }) - b.Run("buffer", func(b *testing.B) { - benchConn(b, false, false, size) - }) - }) - } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) }) - b.Run("echo", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) - }) + // Section 3. + // We skip the per octet sending as it will add too much complexity. + t.Run("reserved", func(t *testing.T) { + t.Parallel() + + var testCases = []struct { + name string + header websocket.Header + }{ + { + name: "rsv1", + header: websocket.Header{ + Fin: true, + Rsv1: true, + OpCode: websocket.OpClose, + PayloadLength: 0, + }, + }, + { + name: "rsv2", + header: websocket.Header{ + Fin: true, + Rsv2: true, + OpCode: websocket.OpPong, + PayloadLength: 0, + }, + }, + { + name: "rsv3", + header: websocket.Header{ + Fin: true, + Rsv3: true, + OpCode: websocket.OpBinary, + PayloadLength: 0, + }, + }, + { + name: "rsvAll", + header: websocket.Header{ + Fin: true, + Rsv1: true, + Rsv2: true, + Rsv3: true, + OpCode: websocket.OpText, + PayloadLength: 0, + }, + }, + } + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + err := assertEcho(ctx, c, websocket.MessageText, 4096) + if err != nil { + return err + } + err = c.WriteHeader(ctx, tc.header) + if err != nil { + return err + } + err = c.Flush() + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + }) + + // Section 4. + t.Run("opcodes", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + opcode websocket.OpCode + payload bool + echo bool + ping bool + }{ + // Section 1. + { + name: "3", + opcode: 3, + }, + { + name: "4", + opcode: 4, + payload: true, + }, + { + name: "5", + opcode: 5, + echo: true, + ping: true, + }, + { + name: "6", + opcode: 6, + payload: true, + echo: true, + ping: true, + }, + { + name: "7", + opcode: 7, + payload: true, + echo: true, + ping: true, + }, + + // Section 2. + { + name: "11", + opcode: 11, + }, + { + name: "12", + opcode: 12, + payload: true, + }, + { + name: "13", + opcode: 13, + payload: true, + echo: true, + ping: true, + }, + { + name: "14", + opcode: 14, + payload: true, + echo: true, + ping: true, + }, + { + name: "15", + opcode: 15, + payload: true, + echo: true, + ping: true, + }, + } + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + if tc.echo { + err := assertEcho(ctx, c, websocket.MessageText, 4096) + if err != nil { + return err + } + } + + p := []byte(nil) + if tc.payload { + p = randBytes(rand.Intn(4096) + 1) + } + _, err := c.WriteFrame(ctx, true, tc.opcode, p) + if err != nil { + return err + } + if tc.ping { + _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) + if err != nil { + return err + } + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + }) + + // Section 5. + t.Run("fragmentation", func(t *testing.T) { + t.Parallel() + + // 5.1 to 5.8 + testCases := []struct { + name string + opcode websocket.OpCode + success bool + pingInBetween bool + }{ + { + name: "ping", + opcode: websocket.OpPing, + success: false, + }, + { + name: "pong", + opcode: websocket.OpPong, + success: false, + }, + { + name: "text", + opcode: websocket.OpText, + success: true, + }, + { + name: "textPing", + opcode: websocket.OpText, + success: true, + pingInBetween: true, + }, + } + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(16) + _, err := c.WriteFrame(ctx, false, tc.opcode, p1) + if err != nil { + return err + } + err = c.BW().Flush() + if err != nil { + return err + } + if !tc.success { + _, _, err = c.Read(ctx) + return assertCloseStatus(err, websocket.StatusProtocolError) + } + + if tc.pingInBetween { + _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) + if err != nil { + return err + } + } + + p2 := randBytes(16) + _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2) + if err != nil { + return err + } + + err = assertReadFrame(ctx, c, tc.opcode, p1) + if err != nil { + return err + } + + if tc.pingInBetween { + err = assertReadFrame(ctx, c, websocket.OpPong, p1) + if err != nil { + return err + } + } + + return assertReadFrame(ctx, c, websocket.OpContinuation, p2) + }) + } + + t.Run("unexpectedContinuation", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + fin bool + textFirst bool + }{ + { + name: "fin", + fin: true, + }, + { + name: "noFin", + fin: false, + }, + { + name: "echoFirst", + fin: false, + textFirst: true, + }, + // The rest of the tests in this section get complicated and do not inspire much confidence. + } + + for _, tc := range testCases { + tc := tc + run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { + if tc.textFirst { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + return err + } + p1 := randBytes(32) + _, err = w.Write(p1) + if err != nil { + return err + } + p2 := randBytes(32) + _, err = w.Write(p2) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) + if err != nil { + return err + } + } + + _, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32)) + if err != nil { + return err + } + err = c.BW().Flush() + if err != nil { + return err + } + + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + + run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(32) + _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32)) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + + run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(32) + p2 := randBytes(32) + p3 := randBytes(32) + p4 := randBytes(32) + p5 := randBytes(32) + + _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2) + if err != nil { + return err + } + + _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) + if err != nil { + return err + } + + time.Sleep(time.Second) + + _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3) + if err != nil { + return err + } + _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4) + if err != nil { + return err + } + + _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) + if err != nil { + return err + } + + _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5) + if err != nil { + return err + } + + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpPong, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p3) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p4) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpPong, p1) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, p5) + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) + if err != nil { + return err + } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + }) + }) + + // Section 7 + t.Run("closeHandling", func(t *testing.T) { + t.Parallel() + + // 1.1 - 1.4 is useless. + run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error { + p1 := randBytes(32) + _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) + if err != nil { + return err + } + err = c.Flush() + if err != nil { + return err + } + _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") + if err != nil { + return err + } + err = assertReadFrame(ctx, c, websocket.OpText, p1) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) + }) + + run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error { + // 262144 bytes. + p1 := randBytes(1 << 18) + err := c.Write(ctx, websocket.MessageText, p1) + if err != nil { + return err + } + _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") + if err != nil { + return err + } + err = assertReadMessage(ctx, c, websocket.MessageText, p1) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) + }) + + run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpClose, nil) + if err != nil { + return err + } + return assertReadFrame(ctx, c, websocket.OpClose, []byte{}) + }) + + run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1}) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + + run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + + run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, randString(16)) + }) + + run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, randString(123)) + }) + + run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, websocket.OpClose, + append([]byte{0x03, 0xE8}, randString(124)...), + ) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + + t.Run("validCloses", func(t *testing.T) { + t.Parallel() + + codes := [...]websocket.StatusCode{ + 1000, + 1001, + 1002, + 1003, + 1007, + 1008, + 1009, + 1010, + 1011, + 3000, + 3999, + 4000, + 4999, + } + for _, code := range codes { + run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { + return assertCloseHandshake(ctx, c, code, randString(32)) + }) + } + }) + + t.Run("invalidCloseCodes", func(t *testing.T) { + t.Parallel() + + codes := []websocket.StatusCode{ + 0, + 999, + 1004, + 1005, + 1006, + 1016, + 1100, + 2000, + 2999, + 5000, + 65535, + } + for _, code := range codes { + run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { + p := make([]byte, 2) + binary.BigEndian.PutUint16(p, uint16(code)) + p = append(p, randBytes(32)...) + _, err := c.WriteFrame(ctx, true, websocket.OpClose, p) + if err != nil { + return err + } + return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) + }) + } + }) + }) + + // Section 9. + t.Run("limits", func(t *testing.T) { + t.Parallel() + + t.Run("unfragmentedEcho", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 1 << 16, // 65536 + 1 << 18, // 262144 + // Anything higher is completely unnecessary. + } + + for _, l := range lengths { + l := l + run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { + return assertEcho(ctx, c, websocket.MessageBinary, l) + }) + } + }) + + t.Run("fragmentedEcho", func(t *testing.T) { + t.Parallel() + + fragments := []int{ + 64, + 256, + 1 << 10, + 1 << 12, + 1 << 14, + 1 << 16, + 1 << 18, + } + + for _, l := range fragments { + fragmentLength := l + run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + return err + } + b := randBytes(1 << 18) + for i := 0; i < len(b); { + j := i + fragmentLength + if j > len(b) { + j = len(b) + } + + _, err = w.Write(b[i:j]) + if err != nil { + return err + } + + i = j + } + err = w.Close() + if err != nil { + return err + } + + err = assertReadMessage(ctx, c, websocket.MessageText, b) + if err != nil { + return err + } + return assertCloseHandshake(ctx, c, websocket.StatusNormalClosure, "") + }) + } + }) + + t.Run("latencyEcho", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 0, + 16, + 64, + } + + for _, l := range lengths { + l := l + run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { + for i := 0; i < 1000; i++ { + err := assertEcho(ctx, c, websocket.MessageBinary, l) + if err != nil { + return err + } + } + return nil + }) + } + }) + }) +} + +func echoLoop(ctx context.Context, c *websocket.Conn) { + defer c.Close(websocket.StatusInternalError, "") + + c.SetReadLimit(1 << 40) + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32768) + echo := func() error { + typ, r, err := c.Reader(ctx) + if err != nil { + return err + } + + w, err := c.Writer(ctx, typ) + if err != nil { + return err + } + + _, err = io.CopyBuffer(w, r, b) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + + return nil + } + + for { + err := echo() + if err != nil { + return } - }) -} - -func assertCloseStatus(err error, code websocket.StatusCode) error { - var cerr websocket.CloseError - if !xerrors.As(err, &cerr) { - return xerrors.Errorf("no websocket close error in error chain: %+v", err) + } +} + +func assertCloseStatus(err error, code websocket.StatusCode) error { + var cerr websocket.CloseError + if !xerrors.As(err, &cerr) { + return xerrors.Errorf("no websocket close error in error chain: %+v", err) } return assertEqualf(code, cerr.Code, "unexpected status code") } @@ -1360,6 +1913,31 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) (er return assertEqualf(exp, act, "unexpected JSON") } +func randBytes(n int) []byte { + return make([]byte, n) +} + +func randString(n int) string { + return string(randBytes(n)) +} + +func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) error { + p := randBytes(n) + err := c.Write(ctx, typ, p) + if err != nil { + return err + } + typ2, p2, err := c.Read(ctx) + if err != nil { + return err + } + err = assertEqualf(typ, typ2, "unexpected data type") + if err != nil { + return err + } + return assertEqualf(p, p2, "unexpected payload") +} + func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { expType := reflect.TypeOf(exp) actv := reflect.New(expType.Elem()) @@ -1378,7 +1956,7 @@ func assertSubprotocol(c *websocket.Conn, exp string) error { func assertEqualf(exp, act interface{}, f string, v ...interface{}) error { if diff := cmpDiff(exp, act); diff != "" { - return xerrors.Errorf(f+": %v", append(v, diff)) + return xerrors.Errorf(f+": %v", append(v, diff)...) } return nil } @@ -1405,3 +1983,73 @@ func assertErrorIs(exp, act error) error { } return nil } + +func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.OpCode, p []byte) error { + actOpcode, actP, err := c.ReadFrame(ctx) + if err != nil { + return err + } + err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) + if err != nil { + return err + } + return assertEqualf(p, actP, "unexpected frame %v payload", opcode) +} + +func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error { + actOpcode, actP, err := c.ReadFrame(ctx) + if err != nil { + return err + } + err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) + if err != nil { + return err + } + ce, err := websocket.ParseClosePayload(actP) + if err != nil { + return xerrors.Errorf("failed to parse close frame payload: %w", err) + } + return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) +} + +func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error { + p, err := c.WriteClose(ctx, code, reason) + if err != nil { + return err + } + return assertReadFrame(ctx, c, websocket.OpClose, p) +} + +func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error { + err := c.WriteHeader(ctx, websocket.Header{ + Fin: true, + OpCode: websocket.OpPing, + PayloadLength: int64(l), + }) + if err != nil { + return err + } + for i := 0; i < l; i++ { + err = c.BW().WriteByte(0xFE) + if err != nil { + return err + } + err = c.BW().Flush() + if err != nil { + return err + } + } + return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l)) +} + +func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error { + actTyp, actP, err := c.Read(ctx) + if err != nil { + return err + } + err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) + if err != nil { + return err + } + return assertEqualf(p, actP, "unexpected frame %v payload", actTyp) +}