From 50179241fe3edeeae757c609a64429cc4c0abd14 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Thu, 10 Oct 2019 11:55:26 -0400
Subject: [PATCH] Cleanup CloseStatus/CloseError docs and improve wasm test
 script

---
 README.md                   |  2 +-
 ci/test.mk                  |  3 ++-
 ci/wasmtest.sh              |  4 +--
 conn.go                     | 18 ++++++++++---
 conn_test.go                | 25 ++++---------------
 doc.go                      |  2 +-
 example_test.go             |  2 +-
 internal/wsgrace/wsgrace.go | 50 +++++++++++++++++++++++++++++++++++++
 internal/wsjstest/main.go   | 21 +++++++++++++---
 9 files changed, 94 insertions(+), 33 deletions(-)
 create mode 100644 internal/wsgrace/wsgrace.go

diff --git a/README.md b/README.md
index e7fea3a..9dd5d0a 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ For a production quality example that shows off the full API, see the [echo exam
 
 Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError).
 There is also [websocket.CloseStatus](https://godoc.org/nhooyr.io/websocket#CloseStatus) to quickly grab the close status code out of a [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError).
-See the [CloseError godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseError).
+See the [CloseStatus godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseStatus).
 
 ### Server
 
diff --git a/ci/test.mk b/ci/test.mk
index 0fe0ce1..b86abb7 100644
--- a/ci/test.mk
+++ b/ci/test.mk
@@ -12,11 +12,12 @@ codecov: _gotest
 	curl -s https://codecov.io/bash | bash -s -- -Z -f ci/out/coverage.prof
 
 _gotest:
-	echo "--- gotest" && go test -parallel=32 -coverprofile=ci/out/coverage.prof -coverpkg=./... ./...
+	echo "--- gotest" && go test -parallel=32 -coverprofile=ci/out/coverage.prof -coverpkg=./... $$TESTFLAGS ./...
 	sed -i '/_stringer\.go/d' ci/out/coverage.prof
 	sed -i '/wsjstest\/main\.go/d' ci/out/coverage.prof
 	sed -i '/wsecho\.go/d' ci/out/coverage.prof
 	sed -i '/assert\.go/d' ci/out/coverage.prof
+	sed -i '/wsgrace\.go/d' ci/out/coverage.prof
 
 gotest-wasm: wsjstest
 	echo "--- wsjstest" && ./ci/wasmtest.sh
diff --git a/ci/wasmtest.sh b/ci/wasmtest.sh
index 586efec..f285fdf 100755
--- a/ci/wasmtest.sh
+++ b/ci/wasmtest.sh
@@ -5,14 +5,14 @@ set -euo pipefail
 wsjstestOut="$(mktemp -d)/wsjstestOut"
 mkfifo "$wsjstestOut"
 timeout 45s wsjstest > "$wsjstestOut" &
-wsjstestPID="$!"
 
 WS_ECHO_SERVER_URL="$(head -n 1 "$wsjstestOut")"
 export WS_ECHO_SERVER_URL
 
 GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./...
 
-if ! wait "$wsjstestPID" ; then
+kill %%
+if ! wait %% ; then
   echo "wsjstest exited unsuccessfully"
   exit 1
 fi
diff --git a/conn.go b/conn.go
index b7b9360..43a9439 100644
--- a/conn.go
+++ b/conn.go
@@ -175,9 +175,14 @@ func (c *Conn) timeoutLoop() {
 
 		case <-readCtx.Done():
 			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
-			// Guaranteed to eventually close the connection since it will not try and read
-			// but only write.
-			go c.exportedClose(StatusPolicyViolation, "read timed out", false)
+			// Guaranteed to eventually close the connection since we can only ever send
+			// one close frame.
+			go func() {
+				c.exportedClose(StatusPolicyViolation, "read timed out", true)
+				// Ensure the connection closes, i.e if we already sent a close frame and timed out
+				// to read the peer's close frame.
+				c.close(nil)
+			}()
 			readCtx = context.Background()
 		case <-writeCtx.Done():
 			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
@@ -339,6 +344,13 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 
 		err = fmt.Errorf("received close: %w", ce)
 		c.writeClose(b, err, false)
+
+		if ctx.Err() != nil {
+			// The above close probably has been returned by the peer in response
+			// to our read timing out so we have to return the read timed out error instead.
+			return fmt.Errorf("read timed out: %w", ctx.Err())
+		}
+
 		return err
 	default:
 		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
diff --git a/conn_test.go b/conn_test.go
index 8dcff94..1acdf59 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -22,7 +22,6 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -34,6 +33,7 @@ import (
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket/internal/assert"
 	"nhooyr.io/websocket/internal/wsecho"
+	"nhooyr.io/websocket/internal/wsgrace"
 	"nhooyr.io/websocket/wsjson"
 	"nhooyr.io/websocket/wspb"
 )
@@ -927,16 +927,7 @@ func TestConn(t *testing.T) {
 }
 
 func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) {
-	var conns int64
 	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		atomic.AddInt64(&conns, 1)
-		defer atomic.AddInt64(&conns, -1)
-
-		ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
-		defer cancel()
-
-		r = r.WithContext(ctx)
-
 		err := fn(w, r)
 		if err != nil {
 			tb.Errorf("server failed: %+v", err)
@@ -947,18 +938,12 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e
 	} else {
 		s = httptest.NewServer(h)
 	}
+	closeFn2 := wsgrace.Grace(s.Config)
 	return s, func() {
-		s.Close()
-
-		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-		defer cancel()
-
-		for atomic.LoadInt64(&conns) > 0 {
-			if ctx.Err() != nil {
-				tb.Fatalf("waiting for server to come down timed out: %v", ctx.Err())
-			}
+		err := closeFn2()
+		if err != nil {
+			tb.Fatal(err)
 		}
-
 	}
 }
 
diff --git a/doc.go b/doc.go
index 1610eed..b29d2cd 100644
--- a/doc.go
+++ b/doc.go
@@ -17,7 +17,7 @@
 //
 // Use the errors.As function new in Go 1.13 to check for websocket.CloseError.
 // Or use the CloseStatus function to grab the StatusCode out of a websocket.CloseError
-// See the CloseError example.
+// See the CloseStatus example.
 //
 // Wasm
 //
diff --git a/example_test.go b/example_test.go
index 1cb3d79..bc603af 100644
--- a/example_test.go
+++ b/example_test.go
@@ -64,7 +64,7 @@ func ExampleDial() {
 
 // This example dials a server and then expects to be disconnected with status code
 // websocket.StatusNormalClosure.
-func ExampleCloseError() {
+func ExampleCloseStatus() {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 	defer cancel()
 
diff --git a/internal/wsgrace/wsgrace.go b/internal/wsgrace/wsgrace.go
new file mode 100644
index 0000000..513af1f
--- /dev/null
+++ b/internal/wsgrace/wsgrace.go
@@ -0,0 +1,50 @@
+package wsgrace
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"sync/atomic"
+	"time"
+)
+
+// Grace wraps s.Handler to gracefully shutdown WebSocket connections.
+// The returned function must be used to close the server instead of s.Close.
+func Grace(s *http.Server) (closeFn func() error) {
+	h := s.Handler
+	var conns int64
+	s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt64(&conns, 1)
+		defer atomic.AddInt64(&conns, -1)
+
+		ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
+		defer cancel()
+
+		r = r.WithContext(ctx)
+
+		h.ServeHTTP(w, r)
+	})
+
+	return func() error {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
+
+		err := s.Shutdown(ctx)
+		if err != nil {
+			return fmt.Errorf("server shutdown failed: %v", err)
+		}
+
+		t := time.NewTicker(time.Millisecond * 10)
+		defer t.Stop()
+		for {
+			select {
+			case <-t.C:
+				if atomic.LoadInt64(&conns) == 0 {
+					return nil
+				}
+			case <-ctx.Done():
+				return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err())
+			}
+		}
+	}
+}
diff --git a/internal/wsjstest/main.go b/internal/wsjstest/main.go
index b8b1cba..96eee2c 100644
--- a/internal/wsjstest/main.go
+++ b/internal/wsjstest/main.go
@@ -8,14 +8,18 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"os"
-	"runtime"
+	"os/signal"
 	"strings"
+	"syscall"
 
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket/internal/wsecho"
+	"nhooyr.io/websocket/internal/wsgrace"
 )
 
 func main() {
+	log.SetPrefix("wsecho")
+
 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 			Subprotocols:       []string{"echo"},
@@ -30,11 +34,20 @@ func main() {
 		if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
 			log.Fatalf("unexpected echo loop error: %+v", err)
 		}
-
-		os.Exit(0)
 	}))
+	closeFn := wsgrace.Grace(s.Config)
+	defer func() {
+		err := closeFn()
+		if err != nil {
+			log.Fatal(err)
+		}
+	}()
 
 	wsURL := strings.Replace(s.URL, "http", "ws", 1)
 	fmt.Printf("%v\n", wsURL)
-	runtime.Goexit()
+
+	sigs := make(chan os.Signal, 1)
+	signal.Notify(sigs, syscall.SIGTERM)
+
+	<-sigs
 }
-- 
GitLab