From 0236290a81e00191840a77024f19d5cefcb05638 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sat, 20 Apr 2019 14:12:56 -0400 Subject: [PATCH] Refactor library to use option structs and add better docs/benchmarks --- .github/main.workflow | 12 ++- .github/test/entrypoint.sh | 17 ---- .gitignore | 1 + README.md | 14 +-- accept.go | 84 +++++++---------- bench_test.go | 139 ++++++++++++++++------------- ci/bench/Dockerfile | 10 +++ ci/bench/entrypoint.sh | 15 ++++ {.github => ci}/fmt/Dockerfile | 0 {.github => ci}/fmt/entrypoint.sh | 2 +- {.github => ci}/lib.sh | 0 {.github => ci}/lint/Dockerfile | 0 {.github => ci}/lint/entrypoint.sh | 2 +- ci/run.sh | 40 +++++++++ {.github => ci}/test/Dockerfile | 0 ci/test/entrypoint.sh | 25 ++++++ dial.go | 75 +++++----------- {.github => docs}/CONTRIBUTING.md | 2 +- example_test.go | 12 +-- export_test.go | 18 ++++ go.mod | 2 +- go.sum | 4 +- json.go | 16 +++- messagetype.go | 7 +- statuscode.go | 3 +- test.sh | 28 ------ websocket.go | 50 ++++------- websocket_test.go | 109 +++++++++++++++------- 28 files changed, 384 insertions(+), 303 deletions(-) delete mode 100755 .github/test/entrypoint.sh create mode 100644 ci/bench/Dockerfile create mode 100755 ci/bench/entrypoint.sh rename {.github => ci}/fmt/Dockerfile (100%) rename {.github => ci}/fmt/entrypoint.sh (94%) rename {.github => ci}/lib.sh (100%) rename {.github => ci}/lint/Dockerfile (100%) rename {.github => ci}/lint/entrypoint.sh (84%) create mode 100755 ci/run.sh rename {.github => ci}/test/Dockerfile (100%) create mode 100755 ci/test/entrypoint.sh rename {.github => docs}/CONTRIBUTING.md (85%) create mode 100644 export_test.go delete mode 100755 test.sh diff --git a/.github/main.workflow b/.github/main.workflow index c4947b0..d87db4f 100644 --- a/.github/main.workflow +++ b/.github/main.workflow @@ -1,17 +1,21 @@ workflow "main" { on = "push" - resolves = ["fmt", "lint", "test"] + resolves = ["fmt", "lint", "test", "bench"] } action "lint" { - uses = "./.github/lint" + uses = "../test/lint" } action "fmt" { - uses = "./.github/fmt" + uses = "../test/fmt" } action "test" { - uses = "./.github/test" + uses = "../test/test" secrets = ["CODECOV_TOKEN"] } + +action "bench" { + uses = "../test/bench" +} diff --git a/.github/test/entrypoint.sh b/.github/test/entrypoint.sh deleted file mode 100755 index 3090348..0000000 --- a/.github/test/entrypoint.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash - -source .github/lib.sh || exit 1 - -COVERAGE_PROFILE=$(mktemp) -go test -race "-coverprofile=${COVERAGE_PROFILE}" -vet=off ./... -go tool cover "-func=${COVERAGE_PROFILE}" - -if [[ $CI ]]; then - bash <(curl -s https://codecov.io/bash) -f "$COVERAGE_PROFILE" -else - go tool cover "-html=${COVERAGE_PROFILE}" -o=coverage.html - - set +x - echo - echo "please open coverage.html to see detailed test coverage stats" -fi diff --git a/.gitignore b/.gitignore index 35ecb6b..9d6b49c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ coverage.html wstest_reports websocket.test +profs diff --git a/README.md b/README.md index 975aefa..256d192 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,12 @@ websocket is a minimal and idiomatic WebSocket library for Go. -This library is in heavy development. +At minimum Go 1.12 is required as websocket uses a new [feature](https://github.com/golang/go/issues/26937#issuecomment-415855861) in net/http +to perform WebSocket handshakes. + +This library is not final and the API is subject to change. + +If you have any feedback, please feel free to open an issue. ## Install @@ -15,8 +20,8 @@ go get nhooyr.io/websocket ## Features - Full support of the WebSocket protocol -- Only depends on stdlib -- Simple to use +- Zero dependencies outside of the stdlib +- Very minimal and carefully considered API - context.Context is first class - net/http is used for WebSocket dials and upgrades - Thoroughly tested, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) @@ -26,9 +31,6 @@ go get nhooyr.io/websocket - [ ] WebSockets over HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) - [ ] Deflate extension support [#5](https://github.com/nhooyr/websocket/issues/5) -- [ ] More optimization [#11](https://github.com/nhooyr/websocket/issues/11) -- [ ] WASM [#15](https://github.com/nhooyr/websocket/issues/15) -- [ ] Ping/pongs [#1](https://github.com/nhooyr/websocket/issues/1) ## Example diff --git a/accept.go b/accept.go index e0c31ef..2a7ab95 100644 --- a/accept.go +++ b/accept.go @@ -12,43 +12,31 @@ import ( "golang.org/x/xerrors" ) -// AcceptOption is an option that can be passed to Accept. -// The implementations of this interface are printable. -type AcceptOption interface { - acceptOption() -} - -type acceptSubprotocols []string - -func (o acceptSubprotocols) acceptOption() {} - -// AcceptSubprotocols lists the websocket subprotocols that Accept will negotiate with a client. -// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to -// reject it, close the connection if c.Subprotocol() == "". -func AcceptSubprotocols(protocols ...string) AcceptOption { - return acceptSubprotocols(protocols) -} - -type acceptInsecureOrigin struct{} - -func (o acceptInsecureOrigin) acceptOption() {} - -// AcceptInsecureOrigin disables Accept's origin verification -// behaviour. By default Accept only allows the handshake to -// succeed if the javascript that is initiating the handshake -// is on the same domain as the server. This is to prevent CSRF -// when secure data is stored in cookies. -// -// See https://stackoverflow.com/a/37837709/4283659 -// -// Use this if you want a WebSocket server any javascript can -// connect to or you want to perform Origin verification yourself -// and allow some whitelist of domains. -// -// Ensure you understand exactly what the above means before you use -// this option in conjugation with cookies containing secure data. -func AcceptInsecureOrigin() AcceptOption { - return acceptInsecureOrigin{} +// AcceptOptions represents the options available to pass to Accept. +type AcceptOptions struct { + // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. + // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to + // reject it, close the connection if c.Subprotocol() == "". + Subprotocols []string + + // InsecureSkipVerify disables Accept's origin verification + // behaviour. By default Accept only allows the handshake to + // succeed if the javascript that is initiating the handshake + // is on the same domain as the server. This is to prevent CSRF + // when secure data is stored in a cookie as there is no same + // origin policy for WebSockets. In other words, javascript from + // any domain can perform a WebSocket dial on an arbitrary server. + // This dial will include cookies which means the arbitrary javascript + // can perform actions as the authenticated user. + // + // See https://stackoverflow.com/a/37837709/4283659 + // + // The only time you need this is if your javascript is running on a different domain + // than your WebSocket server. + // Please think carefully about whether you really need this option before you use it. + // If you do, remember if you store secure data in cookies, you wil need to verify the + // Origin header. + InsecureSkipVerify bool } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { @@ -88,26 +76,14 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to WebSocket. // Accept will reject the handshake if the Origin is not the same as the Host unless -// the AcceptInsecureOrigin option is passed. -// Accept uses w to write the handshake response so the timeouts on the http.Server apply. -func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) { - var subprotocols []string - verifyOrigin := true - for _, opt := range opts { - switch opt := opt.(type) { - case acceptInsecureOrigin: - verifyOrigin = false - case acceptSubprotocols: - subprotocols = []string(opt) - } - } - +// the InsecureSkipVerify option is set. +func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { err := verifyClientRequest(w, r) if err != nil { return nil, err } - if verifyOrigin { + if !opts.InsecureSkipVerify { err = authenticateOrigin(r) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) @@ -127,7 +103,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn handleKey(w, r) - subproto := selectSubprotocol(r, subprotocols) + subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { w.Header().Set("Sec-WebSocket-Protocol", subproto) } @@ -190,5 +166,5 @@ func authenticateOrigin(r *http.Request) error { if strings.EqualFold(u.Host, r.Host) { return nil } - return xerrors.Errorf("request origin %q is not authorized", origin) + return xerrors.Errorf("request origin %q is not authorized for host %v", origin, r.Host) } diff --git a/bench_test.go b/bench_test.go index 66331e0..62c435d 100644 --- a/bench_test.go +++ b/bench_test.go @@ -12,76 +12,95 @@ import ( "nhooyr.io/websocket" ) -func BenchmarkConn(b *testing.B) { - b.StopTimer() +func benchConn(b *testing.B, stream bool) { + name := "buffered" + if stream { + name = "stream" + } - s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, - websocket.AcceptSubprotocols("echo"), - ) - if err != nil { - b.Logf("server handshake failed: %+v", err) - return - } - echoLoop(r.Context(), c) - })) - defer closeFn() + b.Run(name, func(b *testing.B) { + s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + b.Logf("server handshake failed: %+v", err) + return + } + if stream { + streamEchoLoop(r.Context(), c) + } else { + bufferedEchoLoop(r.Context(), c) + } - wsURL := strings.Replace(s.URL, "http", "ws", 1) + })) + defer closeFn() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() + wsURL := strings.Replace(s.URL, "http", "ws", 1) - c, _, err := websocket.Dial(ctx, wsURL) - if err != nil { - b.Fatalf("failed to dial: %v", err) - } - defer c.Close(websocket.StatusInternalError, "") + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() - runN := func(n int) { - b.Run(strconv.Itoa(n), func(b *testing.B) { - msg := []byte(strings.Repeat("2", n)) - buf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ResetTimer() - for i := 0; i < b.N; i++ { - w, err := c.Write(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } + c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{}) + if err != nil { + b.Fatalf("failed to dial: %v", err) + } + defer c.Close(websocket.StatusInternalError, "") - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } + runN := func(n int) { + b.Run(strconv.Itoa(n), func(b *testing.B) { + msg := []byte(strings.Repeat("2", n)) + buf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if stream { + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } - err = w.Close() - if err != nil { - b.Fatal(err) - } + _, err = w.Write(msg) + if err != nil { + b.Fatal(err) + } - _, r, err := c.Read(ctx) - if err != nil { - b.Fatal(err, b.N) - } + err = w.Close() + if err != nil { + b.Fatal(err) + } + } else { + err = c.Write(ctx, websocket.MessageText, msg) + if err != nil { + b.Fatal(err) + } + } + _, r, err := c.Reader(ctx) + if err != nil { + b.Fatal(err, b.N) + } - _, err = io.ReadFull(r, buf) - if err != nil { - b.Fatal(err) + _, err = io.ReadFull(r, buf) + if err != nil { + b.Fatal(err) + } } - } - b.StopTimer() - }) - } + b.StopTimer() + }) + } + + runN(32) + runN(128) + runN(512) + runN(1024) + runN(4096) + runN(16384) + runN(65536) + runN(131072) - runN(32) - runN(128) - runN(512) - runN(1024) - runN(4096) - runN(16384) - runN(65536) - runN(131072) + c.Close(websocket.StatusNormalClosure, "") + }) +} - c.Close(websocket.StatusNormalClosure, "") +func BenchmarkConn(b *testing.B) { + benchConn(b, true) + benchConn(b, false) } diff --git a/ci/bench/Dockerfile b/ci/bench/Dockerfile new file mode 100644 index 0000000..d5ab4a6 --- /dev/null +++ b/ci/bench/Dockerfile @@ -0,0 +1,10 @@ +FROM golang:1.12 + +LABEL "com.github.actions.name"="bench" +LABEL "com.github.actions.description"="bench" +LABEL "com.github.actions.icon"="code" +LABEL "com.github.actions.color"="purple" + +COPY entrypoint.sh /entrypoint.sh + +CMD ["/entrypoint.sh"] diff --git a/ci/bench/entrypoint.sh b/ci/bench/entrypoint.sh new file mode 100755 index 0000000..19edfe9 --- /dev/null +++ b/ci/bench/entrypoint.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +source ci/lib.sh || exit 1 + +mkdir -p profs +go test --vet=off --run=^$ -bench=. \ + ./... +# -cpuprofile=profs/cpu \ +# -memprofile=profs/mem \ +# -blockprofile=profs/block \ +# -mutexprofile=profs/mutex \ + +set +x +echo "profiles are in ./profs +keep in mind that every profiler Go provides is enabled so that may skew the benchmarks" diff --git a/.github/fmt/Dockerfile b/ci/fmt/Dockerfile similarity index 100% rename from .github/fmt/Dockerfile rename to ci/fmt/Dockerfile diff --git a/.github/fmt/entrypoint.sh b/ci/fmt/entrypoint.sh similarity index 94% rename from .github/fmt/entrypoint.sh rename to ci/fmt/entrypoint.sh index 579fc80..14cecc8 100755 --- a/.github/fmt/entrypoint.sh +++ b/ci/fmt/entrypoint.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -source .github/lib.sh || exit 1 +source ci/lib.sh || exit 1 gen() { # Unfortunately, this is the only way to ensure go.mod and go.sum are correct. diff --git a/.github/lib.sh b/ci/lib.sh similarity index 100% rename from .github/lib.sh rename to ci/lib.sh diff --git a/.github/lint/Dockerfile b/ci/lint/Dockerfile similarity index 100% rename from .github/lint/Dockerfile rename to ci/lint/Dockerfile diff --git a/.github/lint/entrypoint.sh b/ci/lint/entrypoint.sh similarity index 84% rename from .github/lint/entrypoint.sh rename to ci/lint/entrypoint.sh index c81f1f1..c539495 100755 --- a/.github/lint/entrypoint.sh +++ b/ci/lint/entrypoint.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -source .github/lib.sh || exit 1 +source ci/lib.sh || exit 1 ( shopt -s globstar nullglob dotglob diff --git a/ci/run.sh b/ci/run.sh new file mode 100755 index 0000000..66bb16a --- /dev/null +++ b/ci/run.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# This script is for local testing. See .github for CI. + +cd "$(dirname "${0}")/.." || exit 1 +source ci/lib.sh || exit 1 + +function docker_run() { + local DIR="$1" + local IMAGE + IMAGE="$(docker build -q "$DIR")" + docker run \ + -it \ + -v "${PWD}:/repo" \ + -v "$(go env GOPATH):/go" \ + -v "$(go env GOCACHE):/root/.cache/go-build" \ + -w /repo \ + "${IMAGE}" +} + +# Use this to analyze benchmark profiles. +if [[ ${1-} == "analyze" ]]; then + docker run \ + -it \ + -v "${PWD}:/repo" \ + -v "$(go env GOPATH):/go" \ + -v "$(go env GOCACHE):/root/.cache/go-build" \ + -w /repo \ + golang:1.12 +fi + +if [[ $# -gt 0 ]]; then + docker_run "ci/$*" + exit 0 +fi + +docker_run ci/fmt +docker_run ci/lint +docker_run ci/test +docker_run ci/bench diff --git a/.github/test/Dockerfile b/ci/test/Dockerfile similarity index 100% rename from .github/test/Dockerfile rename to ci/test/Dockerfile diff --git a/ci/test/entrypoint.sh b/ci/test/entrypoint.sh new file mode 100755 index 0000000..6cb11b0 --- /dev/null +++ b/ci/test/entrypoint.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +source ci/lib.sh || exit 1 + +mkdir -p profs + +set +x +echo "this step includes benchmarks for race detection and coverage purposes +but the numbers will be misleading. please see the bench step for more +accurate numbers" +set -x + +go test -race -coverprofile=profs/coverage --vet=off -bench=. ./... +go tool cover -func=profs/coverage + +if [[ $CI ]]; then + bash <(curl -s https://codecov.io/bash) -f profs/coverage +else + go tool cover -html=profs/coverage -o=coverage.html + + set +x + echo + echo "please open coverage.html to see detailed test coverage stats" + echo "profiles are in ./prof/" +fi diff --git a/dial.go b/dial.go index 99e3c06..d790990 100644 --- a/dial.go +++ b/dial.go @@ -14,40 +14,19 @@ import ( "golang.org/x/xerrors" ) -// DialOption represents a dial option that can be passed to Dial. -// The implementations are printable for easy debugging. -type DialOption interface { - dialOption() -} - -type dialHTTPClient http.Client - -func (o dialHTTPClient) dialOption() {} - -// DialHTTPClient is the http client used for the handshake. -// Its Transport must use HTTP/1.1 and must return writable bodies -// for WebSocket handshakes. -// http.Transport does this correctly. -func DialHTTPClient(hc *http.Client) DialOption { - return (*dialHTTPClient)(hc) -} - -type dialHeader http.Header - -func (o dialHeader) dialOption() {} - -// DialHeader are the HTTP headers included in the handshake request. -func DialHeader(h http.Header) DialOption { - return dialHeader(h) -} - -type dialSubprotocols []string - -func (o dialSubprotocols) dialOption() {} - -// DialSubprotocols accepts a slice of protcols to include in the Sec-WebSocket-Protocol header. -func DialSubprotocols(subprotocols ...string) DialOption { - return dialSubprotocols(subprotocols) +// DialOptions represents the options available to pass to Dial. +type DialOptions struct { + // HTTPClient is the http client used for the handshake. + // Its Transport must use HTTP/1.1 and must return writable bodies + // for WebSocket handshakes. This was introduced in Go 1.12. + // http.Transport does this correctly. + HTTPClient *http.Client + + // Header specifies the HTTP headers included in the handshake request. + Header http.Header + + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string } // We use this key for all client requests as the Sec-WebSocket-Key header is useless. @@ -56,19 +35,12 @@ func DialSubprotocols(subprotocols ...string) DialOption { var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16)) // Dial performs a WebSocket handshake on the given url with the given options. -func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.Response, err error) { - httpClient := http.DefaultClient - var subprotocols []string - header := http.Header{} - for _, o := range opts { - switch o := o.(type) { - case dialSubprotocols: - subprotocols = o - case dialHeader: - header = http.Header(o) - case *dialHTTPClient: - httpClient = (*http.Client)(o) - } +func Dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) { + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } + if opts.Header == nil { + opts.Header = http.Header{} } parsedURL, err := url.Parse(u) @@ -87,16 +59,16 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R req, _ := http.NewRequest("GET", parsedURL.String(), nil) req = req.WithContext(ctx) - req.Header = header + req.Header = opts.Header req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) - if len(subprotocols) > 0 { - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(subprotocols, ",")) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - resp, err := httpClient.Do(req) + resp, err := opts.HTTPClient.Do(req) if err != nil { return nil, nil, err } @@ -121,6 +93,7 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R return nil, resp, xerrors.Errorf("websocket: body is not a read write closer but should be: %T", rwc) } + // TODO pool bufio c := &Conn{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), br: bufio.NewReader(rwc), diff --git a/.github/CONTRIBUTING.md b/docs/CONTRIBUTING.md similarity index 85% rename from .github/CONTRIBUTING.md rename to docs/CONTRIBUTING.md index 180b765..b4ac44b 100644 --- a/.github/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -13,4 +13,4 @@ Be sure to link to an existing issue if one exists. In general, try creating an before making a PR to get some discussion going and to make sure you do not spend time on a PR that may be rejected. -Run `test.sh` to test your changes. You only need docker and bash to run tests. +Run `test/run.sh` to test your changes. You only need docker and bash to run the tests. diff --git a/example_test.go b/example_test.go index 16ffba3..30434e6 100644 --- a/example_test.go +++ b/example_test.go @@ -14,7 +14,9 @@ import ( func ExampleAccept_echo() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("echo")) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) if err != nil { log.Printf("server handshake failed: %v", err) return @@ -30,13 +32,13 @@ func ExampleAccept_echo() { ctx, cancel := context.WithTimeout(r.Context(), time.Minute) defer cancel() - typ, r, err := c.Read(ctx) + typ, r, err := c.Reader(ctx) if err != nil { return err } r = io.LimitReader(r, 32768) - w, err := c.Write(ctx, typ) + w, err := c.Writer(ctx, typ) if err != nil { return err } @@ -74,7 +76,7 @@ func ExampleAccept_echo() { func ExampleAccept() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { log.Printf("server handshake failed: %v", err) return @@ -111,7 +113,7 @@ func ExampleDial() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - c, _, err := websocket.Dial(ctx, "ws://localhost:8080") + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) if err != nil { log.Fatalf("failed to ws dial: %v", err) } diff --git a/export_test.go b/export_test.go new file mode 100644 index 0000000..e4fdddb --- /dev/null +++ b/export_test.go @@ -0,0 +1,18 @@ +package websocket + +import ( + "context" +) + +// Write writes p as a single data frame to the connection. This is an optimization +// method for when the entire message is in memory and does not need to be streamed +// to the peer via Writer. +// +// Both paths are zero allocation but Writer always has +// to write an additional fin frame when Close is called on the writer which +// can result in worse performance if the full message exceeds the buffer size +// which is 4096 right now as then two syscalls will be necessary to complete the message. +// TODO this is no good as we cannot write daata frame msg in between other ones +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + return c.writeControl(ctx, opcode(typ), p) +} diff --git a/go.mod b/go.mod index 7f99b02..928137e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3 golang.org/x/net v0.0.0-20190311183353-d8887717615a golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - golang.org/x/tools v0.0.0-20190329215204-73054e8977d1 + golang.org/x/tools v0.0.0-20190419195823-c39e7748f6eb golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 mvdan.cc/sh v2.6.4+incompatible ) diff --git a/go.sum b/go.sum index d2c667d..0e10a2c 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190329215204-73054e8977d1 h1:rLRH2E2wN5JjGJSVlBe1ioUkCKgb6eoL9X8bDmtEpsk= -golang.org/x/tools v0.0.0-20190329215204-73054e8977d1/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190419195823-c39e7748f6eb h1:JbWwiXQ1L1jWKTGSwj6y63WT+bESGWOhXY8xoAs0yoo= +golang.org/x/tools v0.0.0-20190419195823-c39e7748f6eb/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 h1:1AGvnywFL1aB5KLRxyLseWJI6aSYPo3oF7HSpXdWQdU= golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM= diff --git a/json.go b/json.go index 0d85a5d..b40baf6 100644 --- a/json.go +++ b/json.go @@ -10,7 +10,17 @@ import ( // JSONConn wraps around a Conn with JSON helpers. type JSONConn struct { - *Conn + Conn *Conn +} + +// Subprotocol calls Subprotocol on the underlying Conn. +func (jc JSONConn) Subprotocol() string { + return jc.Conn.Subprotocol() +} + +// Close calls Close on the underlying Conn. +func (jc JSONConn) Close(code StatusCode, reason string) error { + return jc.Conn.Close(code, reason) } // Read reads a json message into v. @@ -23,7 +33,7 @@ func (jc JSONConn) Read(ctx context.Context, v interface{}) error { } func (jc JSONConn) read(ctx context.Context, v interface{}) error { - typ, r, err := jc.Conn.Read(ctx) + typ, r, err := jc.Conn.Reader(ctx) if err != nil { return err } @@ -53,7 +63,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w, err := jc.Conn.Write(ctx, MessageText) + w, err := jc.Conn.Writer(ctx, MessageText) if err != nil { return xerrors.Errorf("failed to get message writer: %w", err) } diff --git a/messagetype.go b/messagetype.go index 54276b3..1fd9cd6 100644 --- a/messagetype.go +++ b/messagetype.go @@ -1,12 +1,15 @@ package websocket -// MessageType represents the Opcode of a WebSocket data frame. +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int //go:generate go run golang.org/x/tools/cmd/stringer -type=MessageType // MessageType constants. const ( - MessageText MessageType = MessageType(opText) + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = MessageType(opText) + // MessageBinary is for binary messages like Protobufs. MessageBinary MessageType = MessageType(opBinary) ) diff --git a/statuscode.go b/statuscode.go index d742195..db5b751 100644 --- a/statuscode.go +++ b/statuscode.go @@ -10,6 +10,7 @@ import ( ) // StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 type StatusCode int //go:generate go run golang.org/x/tools/cmd/stringer -type=StatusCode @@ -42,7 +43,7 @@ const ( ) // CloseError represents an error from a WebSocket close frame. -// Methods on the Conn will only return this for a non normal close code. +// It is returned by a Conn's method when the Connection was closed with a WebSocket close frame. type CloseError struct { Code StatusCode Reason string diff --git a/test.sh b/test.sh deleted file mode 100755 index d6e8e00..0000000 --- a/test.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash - -# This script is for local testing. See .github for CI. - -source .github/lib.sh || exit 1 -cd "$(dirname "${0}")" - -function docker_run() { - local DIR="$1" - local IMAGE - IMAGE="$(docker build -q "$DIR")" - docker run \ - -it \ - -v "${PWD}:/repo" \ - -v "$(go env GOPATH):/go" \ - -v "$(go env GOCACHE):/root/.cache/go-build" \ - -w /repo \ - "${IMAGE}" -} - -if [[ $# -gt 0 ]]; then - docker_run ".github/$*" - exit 0 -fi - -docker_run .github/fmt -docker_run .github/lint -docker_run .github/test diff --git a/websocket.go b/websocket.go index 7992351..4d717ac 100644 --- a/websocket.go +++ b/websocket.go @@ -19,8 +19,9 @@ type control struct { } // Conn represents a WebSocket connection. -// Pings will always be automatically responded to with pongs, you do not -// have to do anything special. +// All methods except Reader can be used concurrently. +// Please be sure to call Close on the connection when you +// are finished with it to release resources. type Conn struct { subprotocol string br *bufio.Reader @@ -116,7 +117,7 @@ func (c *Conn) writeFrame(h header, p []byte) { return } - if h.opcode.controlOp() { + if h.fin { err := c.bw.Flush() if err != nil { c.close(xerrors.Errorf("failed to write to connection: %w", err)) @@ -134,7 +135,6 @@ messageLoop: select { case <-c.closed: return - case dataType = <-c.write: case control := <-c.control: h := header{ fin: true, @@ -147,8 +147,9 @@ messageLoop: case <-c.closed: return case c.writeDone <- struct{}{}: + continue } - continue + case dataType = <-c.write: } var firstSent bool @@ -205,18 +206,6 @@ messageLoop: c.writeFrame(h, nil) - select { - case <-c.closed: - return - case c.writeDone <- struct{}{}: - } - - err := c.bw.Flush() - if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return - } - continue messageLoop } } @@ -369,6 +358,7 @@ func (c *Conn) writePong(p []byte) error { // Close closes the WebSocket connection with the given status code and reason. // It will write a WebSocket close frame with a timeout of 5 seconds. +// Concurrent calls to Close are ok. func (c *Conn) Close(code StatusCode, reason string) error { ce := CloseError{ Code: code, @@ -435,17 +425,17 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } } -// Write returns a writer bounded by the context that will write +// Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // Ensure you close the writer once you have written the entire message. -// Concurrent calls to Write are ok. -func (c *Conn) Write(ctx context.Context, dataType MessageType) (io.WriteCloser, error) { +// Concurrent calls to Writer are ok. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { select { case <-c.closed: return nil, c.closeErr case <-ctx.Done(): return nil, ctx.Err() - case c.write <- dataType: + case c.write <- typ: return messageWriter{ ctx: ctx, c: c, @@ -460,9 +450,6 @@ type messageWriter struct { } // Write writes the given bytes to the WebSocket connection. -// The frame will automatically be fragmented as appropriate -// with the buffers obtained from http.Hijacker. -// Please ensure you call Close once you have written the full message. func (w messageWriter) Write(p []byte) (int, error) { select { case <-w.c.closed: @@ -493,26 +480,19 @@ func (w messageWriter) Close() error { case <-w.ctx.Done(): return w.ctx.Err() case w.c.writeFlush <- struct{}{}: - } - - select { - case <-w.c.closed: - return w.c.closeErr - case <-w.ctx.Done(): - return w.ctx.Err() - case <-w.c.writeDone: return nil } } -// ReadMessage will wait until there is a WebSocket data message to read from the connection. +// Reader will wait until there is a WebSocket data message to read from the connection. // It returns the type of the message and a reader to read it. // The passed context will also bound the reader. // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames and not become stuck waiting for a data message to be read. // Please ensure to read the full message from io.Reader. -// You can only read a single message at a time. -func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) { +// You can only read a single message at a time so do not call this method +// concurrently. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { select { case <-c.closed: diff --git a/websocket_test.go b/websocket_test.go index d6d222d..5655038 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -36,7 +36,9 @@ func TestHandshake(t *testing.T) { { name: "handshake", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("myproto")) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + Subprotocols: []string{"myproto"}, + }) if err != nil { return err } @@ -44,7 +46,9 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, resp, err := websocket.Dial(ctx, u, websocket.DialSubprotocols("myproto")) + c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Subprotocols: []string{"myproto"}, + }) if err != nil { return err } @@ -70,7 +74,7 @@ func TestHandshake(t *testing.T) { { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { return err } @@ -82,7 +86,9 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialSubprotocols("meow")) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Subprotocols: []string{"meow"}, + }) if err != nil { return err } @@ -97,7 +103,9 @@ func TestHandshake(t *testing.T) { { name: "subprotocol", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("echo", "lar")) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + Subprotocols: []string{"echo", "lar"}, + }) if err != nil { return err } @@ -109,7 +117,9 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialSubprotocols("poof", "echo")) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Subprotocols: []string{"poof", "echo"}, + }) if err != nil { return err } @@ -124,7 +134,7 @@ func TestHandshake(t *testing.T) { { name: "badOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err == nil { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected error regarding bad origin") @@ -134,7 +144,9 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", "http://unauthorized.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialHeader(h)) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Header: h, + }) if err == nil { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected handshake failure") @@ -145,7 +157,7 @@ func TestHandshake(t *testing.T) { { name: "acceptSecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin()) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { return err } @@ -154,8 +166,10 @@ func TestHandshake(t *testing.T) { }, client: func(ctx context.Context, u string) error { h := http.Header{} - h.Set("Origin", "https://127.0.0.1") - c, _, err := websocket.Dial(ctx, u, websocket.DialHeader(h)) + h.Set("Origin", u) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Header: h, + }) if err != nil { return err } @@ -166,7 +180,9 @@ func TestHandshake(t *testing.T) { { name: "acceptInsecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin()) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) if err != nil { return err } @@ -176,7 +192,9 @@ func TestHandshake(t *testing.T) { client: func(ctx context.Context, u string) error { h := http.Header{} h.Set("Origin", "https://example.com") - c, _, err := websocket.Dial(ctx, u, websocket.DialHeader(h)) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Header: h, + }) if err != nil { return err } @@ -187,7 +205,7 @@ func TestHandshake(t *testing.T) { { name: "echo", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { return err } @@ -223,7 +241,7 @@ func TestHandshake(t *testing.T) { return nil }, client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) if err != nil { return err } @@ -272,7 +290,7 @@ func TestHandshake(t *testing.T) { if cookie.Value != "myvalue" { return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) } - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { return err } @@ -298,9 +316,9 @@ func TestHandshake(t *testing.T) { hc := &http.Client{ Jar: jar, } - c, _, err := websocket.Dial(ctx, u, - websocket.DialHTTPClient(hc), - ) + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + HTTPClient: hc, + }) if err != nil { return err } @@ -364,14 +382,14 @@ func TestAutobahnServer(t *testing.T) { t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, - websocket.AcceptSubprotocols("echo"), - ) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) if err != nil { t.Logf("server handshake failed: %+v", err) return } - echoLoop(r.Context(), c) + streamEchoLoop(r.Context(), c) })) defer s.Close() @@ -418,7 +436,7 @@ func TestAutobahnServer(t *testing.T) { checkWSTestIndex(t, "./wstest_reports/server/index.json") } -func echoLoop(ctx context.Context, c *websocket.Conn) { +func streamEchoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") ctx, cancel := context.WithTimeout(ctx, time.Minute) @@ -426,12 +444,12 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { b := make([]byte, 32768) echo := func() error { - typ, r, err := c.Read(ctx) + typ, r, err := c.Reader(ctx) if err != nil { return err } - w, err := c.Write(ctx, typ) + w, err := c.Writer(ctx, typ) if err != nil { return err } @@ -457,6 +475,35 @@ func echoLoop(ctx context.Context, c *websocket.Conn) { } } +func bufferedEchoLoop(ctx context.Context, c *websocket.Conn) { + defer c.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 131072+2) + echo := func() error { + typ, r, err := c.Reader(ctx) + if err != nil { + return err + } + + n, err := io.ReadFull(r, b) + if err != io.ErrUnexpectedEOF { + return err + } + + return c.Write(ctx, typ, b[:n]) + } + + for { + err := echo() + if err != nil { + return + } + } +} + // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py func TestAutobahnClient(t *testing.T) { t.Parallel() @@ -510,13 +557,13 @@ func TestAutobahnClient(t *testing.T) { var cases int func() { - c, _, err := websocket.Dial(ctx, "ws://localhost:9001/getCaseCount") + c, _, err := websocket.Dial(ctx, "ws://localhost:9001/getCaseCount", websocket.DialOptions{}) if err != nil { t.Fatalf("failed to dial: %v", err) } defer c.Close(websocket.StatusInternalError, "") - _, r, err := c.Read(ctx) + _, r, err := c.Reader(ctx) if err != nil { t.Fatal(err) } @@ -537,15 +584,15 @@ func TestAutobahnClient(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second*45) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/runCase?case=%v&agent=main", i)) + c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/runCase?case=%v&agent=main", i), websocket.DialOptions{}) if err != nil { t.Fatalf("failed to dial: %v", err) } - echoLoop(ctx, c) + streamEchoLoop(ctx, c) }() } - c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/updateReports?agent=main")) + c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/updateReports?agent=main"), websocket.DialOptions{}) if err != nil { t.Fatalf("failed to dial: %v", err) } -- GitLab