diff --git a/README.md b/README.md index e7fea3aab3f89e4bc6b8933891003963a05816fe..9dd5d0a8a23e2b23fc6cc1b621f6dc46b88aaca0 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ For a production quality example that shows off the full API, see the [echo exam Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). There is also [websocket.CloseStatus](https://godoc.org/nhooyr.io/websocket#CloseStatus) to quickly grab the close status code out of a [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). -See the [CloseError godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseError). +See the [CloseStatus godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseStatus). ### Server diff --git a/ci/test.mk b/ci/test.mk index 0fe0ce19abbe3d9d6fff2cd0d31ddf87e195ae1e..b86abb704d0db1139d28a0fe4514001aaf7f4f03 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -12,11 +12,12 @@ codecov: _gotest curl -s https://codecov.io/bash | bash -s -- -Z -f ci/out/coverage.prof _gotest: - echo "--- gotest" && go test -parallel=32 -coverprofile=ci/out/coverage.prof -coverpkg=./... ./... + echo "--- gotest" && go test -parallel=32 -coverprofile=ci/out/coverage.prof -coverpkg=./... $$TESTFLAGS ./... sed -i '/_stringer\.go/d' ci/out/coverage.prof sed -i '/wsjstest\/main\.go/d' ci/out/coverage.prof sed -i '/wsecho\.go/d' ci/out/coverage.prof sed -i '/assert\.go/d' ci/out/coverage.prof + sed -i '/wsgrace\.go/d' ci/out/coverage.prof gotest-wasm: wsjstest echo "--- wsjstest" && ./ci/wasmtest.sh diff --git a/ci/wasmtest.sh b/ci/wasmtest.sh index 586efec28ef487ba7354aa6f2dcbc7746818e1bc..f285fdf42ba12004f904b7d7bacc9abc8b845e16 100755 --- a/ci/wasmtest.sh +++ b/ci/wasmtest.sh @@ -5,14 +5,14 @@ set -euo pipefail wsjstestOut="$(mktemp -d)/wsjstestOut" mkfifo "$wsjstestOut" timeout 45s wsjstest > "$wsjstestOut" & -wsjstestPID="$!" WS_ECHO_SERVER_URL="$(head -n 1 "$wsjstestOut")" export WS_ECHO_SERVER_URL GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./... -if ! wait "$wsjstestPID" ; then +kill %% +if ! wait %% ; then echo "wsjstest exited unsuccessfully" exit 1 fi diff --git a/conn.go b/conn.go index b7b9360ee9352f3c3a63da60c71b4de7f324598d..43a94397a3caf85a6cf03e0db08af9e38b8e4a7e 100644 --- a/conn.go +++ b/conn.go @@ -175,9 +175,14 @@ func (c *Conn) timeoutLoop() { case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - // Guaranteed to eventually close the connection since it will not try and read - // but only write. - go c.exportedClose(StatusPolicyViolation, "read timed out", false) + // Guaranteed to eventually close the connection since we can only ever send + // one close frame. + go func() { + c.exportedClose(StatusPolicyViolation, "read timed out", true) + // Ensure the connection closes, i.e if we already sent a close frame and timed out + // to read the peer's close frame. + c.close(nil) + }() readCtx = context.Background() case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) @@ -339,6 +344,13 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { err = fmt.Errorf("received close: %w", ce) c.writeClose(b, err, false) + + if ctx.Err() != nil { + // The above close probably has been returned by the peer in response + // to our read timing out so we have to return the read timed out error instead. + return fmt.Errorf("read timed out: %w", ctx.Err()) + } + return err default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) diff --git a/conn_test.go b/conn_test.go index 8dcff944f8662ec187ef87565f84b980f951d615..1acdf5951ec38d8786447702a289a654aa7e44ec 100644 --- a/conn_test.go +++ b/conn_test.go @@ -22,7 +22,6 @@ import ( "reflect" "strconv" "strings" - "sync/atomic" "testing" "time" @@ -34,6 +33,7 @@ import ( "nhooyr.io/websocket" "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/internal/wsecho" + "nhooyr.io/websocket/internal/wsgrace" "nhooyr.io/websocket/wsjson" "nhooyr.io/websocket/wspb" ) @@ -927,16 +927,7 @@ func TestConn(t *testing.T) { } func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { - var conns int64 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) - - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - - r = r.WithContext(ctx) - err := fn(w, r) if err != nil { tb.Errorf("server failed: %+v", err) @@ -947,18 +938,12 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e } else { s = httptest.NewServer(h) } + closeFn2 := wsgrace.Grace(s.Config) return s, func() { - s.Close() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - for atomic.LoadInt64(&conns) > 0 { - if ctx.Err() != nil { - tb.Fatalf("waiting for server to come down timed out: %v", ctx.Err()) - } + err := closeFn2() + if err != nil { + tb.Fatal(err) } - } } diff --git a/doc.go b/doc.go index 1610eed1e33b8bb0264e92f190409a70d196c388..b29d2cdd0cf7442735334323d78696ef687e5b39 100644 --- a/doc.go +++ b/doc.go @@ -17,7 +17,7 @@ // // Use the errors.As function new in Go 1.13 to check for websocket.CloseError. // Or use the CloseStatus function to grab the StatusCode out of a websocket.CloseError -// See the CloseError example. +// See the CloseStatus example. // // Wasm // diff --git a/example_test.go b/example_test.go index 1cb3d799910ddab411f697ae2d009bd72766afc6..bc603aff2a3ef3ce8c933bf85aa79da89559dd94 100644 --- a/example_test.go +++ b/example_test.go @@ -64,7 +64,7 @@ func ExampleDial() { // This example dials a server and then expects to be disconnected with status code // websocket.StatusNormalClosure. -func ExampleCloseError() { +func ExampleCloseStatus() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/internal/wsgrace/wsgrace.go b/internal/wsgrace/wsgrace.go new file mode 100644 index 0000000000000000000000000000000000000000..513af1fe9905eafd12bdfc2c10977e21e5e4fdbc --- /dev/null +++ b/internal/wsgrace/wsgrace.go @@ -0,0 +1,50 @@ +package wsgrace + +import ( + "context" + "fmt" + "net/http" + "sync/atomic" + "time" +) + +// Grace wraps s.Handler to gracefully shutdown WebSocket connections. +// The returned function must be used to close the server instead of s.Close. +func Grace(s *http.Server) (closeFn func() error) { + h := s.Handler + var conns int64 + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&conns, 1) + defer atomic.AddInt64(&conns, -1) + + ctx, cancel := context.WithTimeout(r.Context(), time.Minute) + defer cancel() + + r = r.WithContext(ctx) + + h.ServeHTTP(w, r) + }) + + return func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := s.Shutdown(ctx) + if err != nil { + return fmt.Errorf("server shutdown failed: %v", err) + } + + t := time.NewTicker(time.Millisecond * 10) + defer t.Stop() + for { + select { + case <-t.C: + if atomic.LoadInt64(&conns) == 0 { + return nil + } + case <-ctx.Done(): + return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) + } + } + } +} diff --git a/internal/wsjstest/main.go b/internal/wsjstest/main.go index b8b1cba25b6eb797f02545ca00f53e3dfa5fb3be..96eee2c0543c23a8e72ed9df5201bd34f9574705 100644 --- a/internal/wsjstest/main.go +++ b/internal/wsjstest/main.go @@ -8,14 +8,18 @@ import ( "net/http" "net/http/httptest" "os" - "runtime" + "os/signal" "strings" + "syscall" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/wsecho" + "nhooyr.io/websocket/internal/wsgrace" ) func main() { + log.SetPrefix("wsecho") + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, @@ -30,11 +34,20 @@ func main() { if websocket.CloseStatus(err) != websocket.StatusNormalClosure { log.Fatalf("unexpected echo loop error: %+v", err) } - - os.Exit(0) })) + closeFn := wsgrace.Grace(s.Config) + defer func() { + err := closeFn() + if err != nil { + log.Fatal(err) + } + }() wsURL := strings.Replace(s.URL, "http", "ws", 1) fmt.Printf("%v\n", wsURL) - runtime.Goexit() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + + <-sigs }