From e335b09210e47739545fe30c69f3a0f56ede98a0 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Wed, 26 Feb 2020 14:47:40 -0500
Subject: [PATCH] Use grace in chat example

---
 accept.go              |  4 ++--
 chat-example/go.mod    |  4 +++-
 chat-example/go.sum    | 10 ++++++++--
 chat-example/index.css |  2 +-
 chat-example/index.js  | 13 ++++++++++---
 chat-example/main.go   | 29 +++++++++++++++++++++++++++--
 example_test.go        | 14 +++++++++++---
 grace.go               | 20 ++++++++++++--------
 8 files changed, 74 insertions(+), 22 deletions(-)

diff --git a/accept.go b/accept.go
index 52a9345..dd96c9b 100644
--- a/accept.go
+++ b/accept.go
@@ -76,8 +76,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
 	defer errd.Wrap(&err, "failed to accept WebSocket connection")
 
 	g := graceFromRequest(r)
-	if g != nil && g.isClosing() {
-		err := errors.New("server closing")
+	if g != nil && g.isShuttingdown() {
+		err := errors.New("server shutting down")
 		http.Error(w, err.Error(), http.StatusServiceUnavailable)
 		return nil, err
 	}
diff --git a/chat-example/go.mod b/chat-example/go.mod
index 34fa5a6..c47a5a2 100644
--- a/chat-example/go.mod
+++ b/chat-example/go.mod
@@ -2,4 +2,6 @@ module nhooyr.io/websocket/example-chat
 
 go 1.13
 
-require nhooyr.io/websocket v1.8.2
+require nhooyr.io/websocket v0.0.0
+
+replace nhooyr.io/websocket => ../
diff --git a/chat-example/go.sum b/chat-example/go.sum
index 0755fca..e4bbd62 100644
--- a/chat-example/go.sum
+++ b/chat-example/go.sum
@@ -1,12 +1,18 @@
+github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
 github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
+github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
 github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
+github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
 github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
+github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I=
 github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
 github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y=
 github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA=
-nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ=
diff --git a/chat-example/index.css b/chat-example/index.css
index 2980466..73a8e0f 100644
--- a/chat-example/index.css
+++ b/chat-example/index.css
@@ -5,7 +5,7 @@ body {
 
 #root {
   padding: 40px 20px;
-  max-width: 480px;
+  max-width: 600px;
   margin: auto;
   height: 100vh;
 
diff --git a/chat-example/index.js b/chat-example/index.js
index 8fb3dfb..a42c2d3 100644
--- a/chat-example/index.js
+++ b/chat-example/index.js
@@ -7,8 +7,11 @@
     const conn = new WebSocket(`ws://${location.host}/subscribe`)
 
     conn.addEventListener("close", ev => {
-      console.info("websocket disconnected, reconnecting in 1000ms", ev)
-      setTimeout(dial, 1000)
+      appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true)
+      if (ev.code !== 1001) {
+        appendLog("Reconnecting in 1s", true)
+        setTimeout(dial, 1000)
+      }
     })
     conn.addEventListener("open", ev => {
       console.info("websocket connected")
@@ -34,10 +37,14 @@
   const messageInput = document.getElementById("message-input")
 
   // appendLog appends the passed text to messageLog.
-  function appendLog(text) {
+  function appendLog(text, error) {
     const p = document.createElement("p")
     // Adding a timestamp to each message makes the log easier to read.
     p.innerText = `${new Date().toLocaleTimeString()}: ${text}`
+    if (error) {
+      p.style.color = "red"
+      p.style.fontStyle = "bold"
+    }
     messageLog.append(p)
     return p
   }
diff --git a/chat-example/main.go b/chat-example/main.go
index 2a52092..f985d38 100644
--- a/chat-example/main.go
+++ b/chat-example/main.go
@@ -1,12 +1,16 @@
 package main
 
 import (
+	"context"
 	"errors"
 	"log"
 	"net"
 	"net/http"
 	"os"
+	"os/signal"
 	"time"
+
+	"nhooyr.io/websocket"
 )
 
 func main() {
@@ -38,10 +42,31 @@ func run() error {
 	m.HandleFunc("/subscribe", ws.subscribeHandler)
 	m.HandleFunc("/publish", ws.publishHandler)
 
+	var g websocket.Grace
 	s := http.Server{
-		Handler:      m,
+		Handler:      g.Handler(m),
 		ReadTimeout:  time.Second * 10,
 		WriteTimeout: time.Second * 10,
 	}
-	return s.Serve(l)
+	errc := make(chan error, 1)
+	go func() {
+		errc <- s.Serve(l)
+	}()
+
+	sigs := make(chan os.Signal, 1)
+	signal.Notify(sigs, os.Interrupt)
+	select {
+	case err := <-errc:
+		log.Printf("failed to serve: %v", err)
+	case sig := <-sigs:
+		log.Printf("terminating: %v", sig)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+
+	s.Shutdown(ctx)
+	g.Shutdown(ctx)
+
+	return nil
 }
diff --git a/example_test.go b/example_test.go
index ce049bc..462de37 100644
--- a/example_test.go
+++ b/example_test.go
@@ -167,12 +167,20 @@ func ExampleGrace() {
 		ReadTimeout:  time.Second * 15,
 		WriteTimeout: time.Second * 15,
 	}
-	go s.ListenAndServe()
+
+	errc := make(chan error, 1)
+	go func() {
+		errc <- s.ListenAndServe()
+	}()
 
 	sigs := make(chan os.Signal, 1)
 	signal.Notify(sigs, os.Interrupt)
-	sig := <-sigs
-	log.Printf("recieved %v, shutting down", sig)
+	select {
+	case err := <-errc:
+		log.Printf("failed to listen and serve: %v", err)
+	case sig := <-sigs:
+		log.Printf("terminating: %v", sig)
+	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
 	defer cancel()
diff --git a/grace.go b/grace.go
index 8dadc43..c53cd40 100644
--- a/grace.go
+++ b/grace.go
@@ -15,10 +15,13 @@ import (
 // and then use Close or Shutdown to gracefully close these connections.
 //
 // Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
+// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket
+// connections.
 type Grace struct {
-	mu      sync.Mutex
-	closing bool
-	conns   map[*Conn]struct{}
+	mu           sync.Mutex
+	closed       bool
+	shuttingDown bool
+	conns        map[*Conn]struct{}
 }
 
 // Handler returns a handler that wraps around h to record
@@ -33,10 +36,10 @@ func (g *Grace) Handler(h http.Handler) http.Handler {
 	})
 }
 
-func (g *Grace) isClosing() bool {
+func (g *Grace) isShuttingdown() bool {
 	g.mu.Lock()
 	defer g.mu.Unlock()
-	return g.closing
+	return g.shuttingDown
 }
 
 func graceFromRequest(r *http.Request) *Grace {
@@ -47,7 +50,7 @@ func graceFromRequest(r *http.Request) *Grace {
 func (g *Grace) addConn(c *Conn) error {
 	g.mu.Lock()
 	defer g.mu.Unlock()
-	if g.closing {
+	if g.closed {
 		c.Close(StatusGoingAway, "server shutting down")
 		return errors.New("server shutting down")
 	}
@@ -72,7 +75,8 @@ type gracefulContextKey struct{}
 // connections with StatusGoingAway.
 func (g *Grace) Close() error {
 	g.mu.Lock()
-	g.closing = true
+	g.shuttingDown = true
+	g.closed = true
 	var wg sync.WaitGroup
 	for c := range g.conns {
 		wg.Add(1)
@@ -97,7 +101,7 @@ func (g *Grace) Shutdown(ctx context.Context) error {
 	defer g.Close()
 
 	g.mu.Lock()
-	g.closing = true
+	g.shuttingDown = true
 	g.mu.Unlock()
 
 	// Same poll period used by net/http.
-- 
GitLab