diff --git a/accept.go b/accept.go
index 47e20b52c373f6757475c86c30b06626f3dfc745..dd96c9bd96c2fccaebbaf535e4ce53bbe1f8279f 100644
--- a/accept.go
+++ b/accept.go
@@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
 func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
 	defer errd.Wrap(&err, "failed to accept WebSocket connection")
 
+	g := graceFromRequest(r)
+	if g != nil && g.isShuttingdown() {
+		err := errors.New("server shutting down")
+		http.Error(w, err.Error(), http.StatusServiceUnavailable)
+		return nil, err
+	}
+
 	if opts == nil {
 		opts = &AcceptOptions{}
 	}
@@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
 	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
 	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
 
-	return newConn(connConfig{
+	c := newConn(connConfig{
 		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
 		rwc:            netConn,
 		client:         false,
@@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
 
 		br: brw.Reader,
 		bw: brw.Writer,
-	}), nil
+	})
+
+	if g != nil {
+		err = g.addConn(c)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return c, nil
 }
 
 func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
diff --git a/chat-example/README.md b/chat-example/README.md
index ef06275db3bd1bb6ff2c8c85033e144a2b892965..a4c99a93447f0c4eecacdd48b17a0f285c680de6 100644
--- a/chat-example/README.md
+++ b/chat-example/README.md
@@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin
 
 The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
 `index.html` and then `index.js`.
+
+There are two automated tests for the server included in `chat_test.go`. The first is a simple one
+client echo test. It publishes a single message and ensures it's received.
+
+The second is a complex concurrency test where 10 clients send 128 unique messages
+of max 128 bytes concurrently. The test ensures all messages are seen by every client.
diff --git a/chat-example/chat.go b/chat-example/chat.go
index e6e355d04688bd37f7c2495a5483c794d4453a55..532e50f544fbc39f2ca0a870b0c9cc7ef8ae8e44 100644
--- a/chat-example/chat.go
+++ b/chat-example/chat.go
@@ -3,20 +3,67 @@ package main
 import (
 	"context"
 	"errors"
-	"io"
 	"io/ioutil"
 	"log"
 	"net/http"
 	"sync"
 	"time"
 
+	"golang.org/x/time/rate"
+
 	"nhooyr.io/websocket"
 )
 
 // chatServer enables broadcasting to a set of subscribers.
 type chatServer struct {
-	subscribersMu sync.RWMutex
-	subscribers   map[chan<- []byte]struct{}
+	// subscriberMessageBuffer controls the max number
+	// of messages that can be queued for a subscriber
+	// before it is kicked.
+	//
+	// Defaults to 16.
+	subscriberMessageBuffer int
+
+	// publishLimiter controls the rate limit applied to the publish endpoint.
+	//
+	// Defaults to one publish every 100ms with a burst of 8.
+	publishLimiter *rate.Limiter
+
+	// logf controls where logs are sent.
+	// Defaults to log.Printf.
+	logf func(f string, v ...interface{})
+
+	// serveMux routes the various endpoints to the appropriate handler.
+	serveMux http.ServeMux
+
+	subscribersMu sync.Mutex
+	subscribers   map[*subscriber]struct{}
+}
+
+// newChatServer constructs a chatServer with the defaults.
+func newChatServer() *chatServer {
+	cs := &chatServer{
+		subscriberMessageBuffer: 16,
+		logf:                    log.Printf,
+		subscribers:             make(map[*subscriber]struct{}),
+		publishLimiter:          rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
+	}
+	cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
+	cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
+	cs.serveMux.HandleFunc("/publish", cs.publishHandler)
+
+	return cs
+}
+
+// subscriber represents a subscriber.
+// Messages are sent on the msgs channel and if the client
+// cannot keep up with the messages, closeSlow is called.
+type subscriber struct {
+	msgs      chan []byte
+	closeSlow func()
+}
+
+func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	cs.serveMux.ServeHTTP(w, r)
 }
 
 // subscribeHandler accepts the WebSocket connection and then subscribes
@@ -24,7 +71,7 @@ type chatServer struct {
 func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
 	c, err := websocket.Accept(w, r, nil)
 	if err != nil {
-		log.Print(err)
+		cs.logf("%v", err)
 		return
 	}
 	defer c.Close(websocket.StatusInternalError, "")
@@ -38,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	if err != nil {
-		log.Print(err)
+		cs.logf("%v", err)
+		return
 	}
 }
 
@@ -49,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
 		return
 	}
-	body := io.LimitReader(r.Body, 8192)
+	body := http.MaxBytesReader(w, r.Body, 8192)
 	msg, err := ioutil.ReadAll(body)
 	if err != nil {
 		http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
@@ -57,11 +105,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
 	}
 
 	cs.publish(msg)
+
+	w.WriteHeader(http.StatusAccepted)
 }
 
 // subscribe subscribes the given WebSocket to all broadcast messages.
-// It creates a msgs chan with a buffer of 16 to give some room to slower
-// connections and then registers it. It then listens for all messages
+// It creates a subscriber with a buffered msgs chan to give some room to slower
+// connections and then registers the subscriber. It then listens for all messages
 // and writes them to the WebSocket. If the context is cancelled or
 // an error occurs, it returns and deletes the subscription.
 //
@@ -70,13 +120,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
 func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
 	ctx = c.CloseRead(ctx)
 
-	msgs := make(chan []byte, 16)
-	cs.addSubscriber(msgs)
-	defer cs.deleteSubscriber(msgs)
+	s := &subscriber{
+		msgs: make(chan []byte, cs.subscriberMessageBuffer),
+		closeSlow: func() {
+			c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
+		},
+	}
+	cs.addSubscriber(s)
+	defer cs.deleteSubscriber(s)
 
 	for {
 		select {
-		case msg := <-msgs:
+		case msg := <-s.msgs:
 			err := writeTimeout(ctx, time.Second*5, c, msg)
 			if err != nil {
 				return err
@@ -91,32 +146,31 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
 // It never blocks and so messages to slow subscribers
 // are dropped.
 func (cs *chatServer) publish(msg []byte) {
-	cs.subscribersMu.RLock()
-	defer cs.subscribersMu.RUnlock()
+	cs.subscribersMu.Lock()
+	defer cs.subscribersMu.Unlock()
+
+	cs.publishLimiter.Wait(context.Background())
 
-	for c := range cs.subscribers {
+	for s := range cs.subscribers {
 		select {
-		case c <- msg:
+		case s.msgs <- msg:
 		default:
+			go s.closeSlow()
 		}
 	}
 }
 
-// addSubscriber registers a subscriber with a channel
-// on which to send messages.
-func (cs *chatServer) addSubscriber(msgs chan<- []byte) {
+// addSubscriber registers a subscriber.
+func (cs *chatServer) addSubscriber(s *subscriber) {
 	cs.subscribersMu.Lock()
-	if cs.subscribers == nil {
-		cs.subscribers = make(map[chan<- []byte]struct{})
-	}
-	cs.subscribers[msgs] = struct{}{}
+	cs.subscribers[s] = struct{}{}
 	cs.subscribersMu.Unlock()
 }
 
-// deleteSubscriber deletes the subscriber with the given msgs channel.
-func (cs *chatServer) deleteSubscriber(msgs chan []byte) {
+// deleteSubscriber deletes the given subscriber.
+func (cs *chatServer) deleteSubscriber(s *subscriber) {
 	cs.subscribersMu.Lock()
-	delete(cs.subscribers, msgs)
+	delete(cs.subscribers, s)
 	cs.subscribersMu.Unlock()
 }
 
diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..491499ccee6d8669f695da478736d05ca04fa2af
--- /dev/null
+++ b/chat-example/chat_test.go
@@ -0,0 +1,282 @@
+// +build !js
+
+package main
+
+import (
+	"context"
+	"crypto/rand"
+	"fmt"
+	"math/big"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/time/rate"
+
+	"nhooyr.io/websocket"
+)
+
+func Test_chatServer(t *testing.T) {
+	t.Parallel()
+
+	// This is a simple echo test with a single client.
+	// The client sends a message and ensures it receives
+	// it on its WebSocket.
+	t.Run("simple", func(t *testing.T) {
+		t.Parallel()
+
+		url, closeFn := setupTest(t)
+		defer closeFn()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+		defer cancel()
+
+		cl, err := newClient(ctx, url)
+		assertSuccess(t, err)
+		defer cl.Close()
+
+		expMsg := randString(512)
+		err = cl.publish(ctx, expMsg)
+		assertSuccess(t, err)
+
+		msg, err := cl.nextMessage()
+		assertSuccess(t, err)
+
+		if expMsg != msg {
+			t.Fatalf("expected %v but got %v", expMsg, msg)
+		}
+	})
+
+	// This test is a complex concurrency test.
+	// 10 clients are started that send 128 different
+	// messages of max 128 bytes concurrently.
+	//
+	// The test verifies that every message is seen by ever client
+	// and no errors occur anywhere.
+	t.Run("concurrency", func(t *testing.T) {
+		t.Parallel()
+
+		const nmessages = 128
+		const maxMessageSize = 128
+		const nclients = 10
+
+		url, closeFn := setupTest(t)
+		defer closeFn()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
+
+		var clients []*client
+		var clientMsgs []map[string]struct{}
+		for i := 0; i < nclients; i++ {
+			cl, err := newClient(ctx, url)
+			assertSuccess(t, err)
+			defer cl.Close()
+
+			clients = append(clients, cl)
+			clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize))
+		}
+
+		allMessages := make(map[string]struct{})
+		for _, msgs := range clientMsgs {
+			for m := range msgs {
+				allMessages[m] = struct{}{}
+			}
+		}
+
+		var wg sync.WaitGroup
+		for i, cl := range clients {
+			i := i
+			cl := cl
+
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				err := cl.publishMsgs(ctx, clientMsgs[i])
+				if err != nil {
+					t.Errorf("client %d failed to publish all messages: %v", i, err)
+				}
+			}()
+
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				err := testAllMessagesReceived(cl, nclients*nmessages, allMessages)
+				if err != nil {
+					t.Errorf("client %d failed to receive all messages: %v", i, err)
+				}
+			}()
+		}
+
+		wg.Wait()
+	})
+}
+
+// setupTest sets up chatServer that can be used
+// via the returned url.
+//
+// Defer closeFn to ensure everything is cleaned up at
+// the end of the test.
+//
+// chatServer logs will be logged via t.Logf.
+func setupTest(t *testing.T) (url string, closeFn func()) {
+	cs := newChatServer()
+	cs.logf = t.Logf
+
+	// To ensure tests run quickly under even -race.
+	cs.subscriberMessageBuffer = 4096
+	cs.publishLimiter.SetLimit(rate.Inf)
+
+	var g websocket.Grace
+	s := httptest.NewServer(g.Handler(cs))
+	return s.URL, func() {
+		s.Close()
+		g.Close()
+	}
+}
+
+// testAllMessagesReceived ensures that after n reads, all msgs in msgs
+// have been read.
+func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error {
+	msgs = cloneMessages(msgs)
+
+	for i := 0; i < n; i++ {
+		msg, err := cl.nextMessage()
+		if err != nil {
+			return err
+		}
+		delete(msgs, msg)
+	}
+
+	if len(msgs) != 0 {
+		return fmt.Errorf("did not receive all expected messages: %q", msgs)
+	}
+	return nil
+}
+
+func cloneMessages(msgs map[string]struct{}) map[string]struct{} {
+	msgs2 := make(map[string]struct{}, len(msgs))
+	for m := range msgs {
+		msgs2[m] = struct{}{}
+	}
+	return msgs2
+}
+
+func randMessages(n, maxMessageLength int) map[string]struct{} {
+	msgs := make(map[string]struct{})
+	for i := 0; i < n; i++ {
+		m := randString(randInt(maxMessageLength))
+		if _, ok := msgs[m]; ok {
+			i--
+			continue
+		}
+		msgs[m] = struct{}{}
+	}
+	return msgs
+}
+
+func assertSuccess(t *testing.T, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+type client struct {
+	url string
+	c   *websocket.Conn
+}
+
+func newClient(ctx context.Context, url string) (*client, error) {
+	wsURL := strings.Replace(url, "http://", "ws://", 1)
+	c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil)
+	if err != nil {
+		return nil, err
+	}
+
+	cl := &client{
+		url: url,
+		c:   c,
+	}
+
+	return cl, nil
+}
+
+func (cl *client) publish(ctx context.Context, msg string) (err error) {
+	defer func() {
+		if err != nil {
+			cl.c.Close(websocket.StatusInternalError, "publish failed")
+		}
+	}()
+
+	req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg))
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusAccepted {
+		return fmt.Errorf("publish request failed: %v", resp.StatusCode)
+	}
+	return nil
+}
+
+func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error {
+	for m := range msgs {
+		err := cl.publish(ctx, m)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (cl *client) nextMessage() (string, error) {
+	typ, b, err := cl.c.Read(context.Background())
+	if err != nil {
+		return "", err
+	}
+
+	if typ != websocket.MessageText {
+		cl.c.Close(websocket.StatusUnsupportedData, "expected text message")
+		return "", fmt.Errorf("expected text message but got %v", typ)
+	}
+	return string(b), nil
+}
+
+func (cl *client) Close() error {
+	return cl.c.Close(websocket.StatusNormalClosure, "")
+}
+
+// randString generates a random string with length n.
+func randString(n int) string {
+	b := make([]byte, n)
+	_, err := rand.Reader.Read(b)
+	if err != nil {
+		panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
+	}
+
+	s := strings.ToValidUTF8(string(b), "_")
+	s = strings.ReplaceAll(s, "\x00", "_")
+	if len(s) > n {
+		return s[:n]
+	}
+	if len(s) < n {
+		// Pad with =
+		extra := n - len(s)
+		return s + strings.Repeat("=", extra)
+	}
+	return s
+}
+
+// randInt returns a randomly generated integer between [0, max).
+func randInt(max int) int {
+	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
+	if err != nil {
+		panic(fmt.Sprintf("failed to get random int: %v", err))
+	}
+	return int(x.Int64())
+}
diff --git a/chat-example/go.mod b/chat-example/go.mod
deleted file mode 100644
index 34fa5a69cef29a336862718fdc82ad48057f8704..0000000000000000000000000000000000000000
--- a/chat-example/go.mod
+++ /dev/null
@@ -1,5 +0,0 @@
-module nhooyr.io/websocket/example-chat
-
-go 1.13
-
-require nhooyr.io/websocket v1.8.2
diff --git a/chat-example/go.sum b/chat-example/go.sum
index 0755fca5eb20b3bfae3ab488ab2a6a723352f490..e4bbd62d337c4edbd4e71bd740271cf2544a0467 100644
--- a/chat-example/go.sum
+++ b/chat-example/go.sum
@@ -1,12 +1,18 @@
+github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
 github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
+github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
 github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
+github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
 github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
+github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I=
 github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
 github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y=
 github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA=
-nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ=
diff --git a/chat-example/index.css b/chat-example/index.css
index 2980466285849e233c95bfa0a2c9c2339a185e1c..73a8e0f3af030e225106234ef634a3aea1e5be3d 100644
--- a/chat-example/index.css
+++ b/chat-example/index.css
@@ -5,7 +5,7 @@ body {
 
 #root {
   padding: 40px 20px;
-  max-width: 480px;
+  max-width: 600px;
   margin: auto;
   height: 100vh;
 
diff --git a/chat-example/index.js b/chat-example/index.js
index 8fb3dfb8a6521c9fa4e64cd663556e17ea250ce1..5868e7caeeeebbdfa6cbe5d0a4a8c82f9268ab98 100644
--- a/chat-example/index.js
+++ b/chat-example/index.js
@@ -7,8 +7,11 @@
     const conn = new WebSocket(`ws://${location.host}/subscribe`)
 
     conn.addEventListener("close", ev => {
-      console.info("websocket disconnected, reconnecting in 1000ms", ev)
-      setTimeout(dial, 1000)
+      appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true)
+      if (ev.code !== 1001) {
+        appendLog("Reconnecting in 1s", true)
+        setTimeout(dial, 1000)
+      }
     })
     conn.addEventListener("open", ev => {
       console.info("websocket connected")
@@ -34,17 +37,21 @@
   const messageInput = document.getElementById("message-input")
 
   // appendLog appends the passed text to messageLog.
-  function appendLog(text) {
+  function appendLog(text, error) {
     const p = document.createElement("p")
     // Adding a timestamp to each message makes the log easier to read.
     p.innerText = `${new Date().toLocaleTimeString()}: ${text}`
+    if (error) {
+      p.style.color = "red"
+      p.style.fontStyle = "bold"
+    }
     messageLog.append(p)
     return p
   }
   appendLog("Submit a message to get started!")
 
   // onsubmit publishes the message from the user when the form is submitted.
-  publishForm.onsubmit = ev => {
+  publishForm.onsubmit = async ev => {
     ev.preventDefault()
 
     const msg = messageInput.value
@@ -54,9 +61,16 @@
     messageInput.value = ""
 
     expectingMessage = true
-    fetch("/publish", {
-      method: "POST",
-      body: msg,
-    })
+    try {
+      const resp = await fetch("/publish", {
+        method: "POST",
+        body: msg,
+      })
+      if (resp.status !== 202) {
+        throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`)
+      }
+    } catch (err) {
+      appendLog(`Publish failed: ${err.message}`, true)
+    }
   }
 })()
diff --git a/chat-example/main.go b/chat-example/main.go
index 2a5209244445ee7443866418115635fdede8bc48..1b6f3266cab9d37d320f74a62572830c29bbd47b 100644
--- a/chat-example/main.go
+++ b/chat-example/main.go
@@ -1,12 +1,16 @@
 package main
 
 import (
+	"context"
 	"errors"
 	"log"
 	"net"
 	"net/http"
 	"os"
+	"os/signal"
 	"time"
+
+	"nhooyr.io/websocket"
 )
 
 func main() {
@@ -31,17 +35,32 @@ func run() error {
 	}
 	log.Printf("listening on http://%v", l.Addr())
 
-	var ws chatServer
-
-	m := http.NewServeMux()
-	m.Handle("/", http.FileServer(http.Dir(".")))
-	m.HandleFunc("/subscribe", ws.subscribeHandler)
-	m.HandleFunc("/publish", ws.publishHandler)
-
+	cs := newChatServer()
+	var g websocket.Grace
 	s := http.Server{
-		Handler:      m,
+		Handler:      g.Handler(cs),
 		ReadTimeout:  time.Second * 10,
 		WriteTimeout: time.Second * 10,
 	}
-	return s.Serve(l)
+	errc := make(chan error, 1)
+	go func() {
+		errc <- s.Serve(l)
+	}()
+
+	sigs := make(chan os.Signal, 1)
+	signal.Notify(sigs, os.Interrupt)
+	select {
+	case err := <-errc:
+		log.Printf("failed to serve: %v", err)
+	case sig := <-sigs:
+		log.Printf("terminating: %v", sig)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+
+	s.Shutdown(ctx)
+	g.Shutdown(ctx)
+
+	return nil
 }
diff --git a/ci/test.mk b/ci/test.mk
index c62a25b64834a1f913fcddff87ac48441edba133..291d6beb11502e5447c49374876fc72a3023750b 100644
--- a/ci/test.mk
+++ b/ci/test.mk
@@ -11,6 +11,7 @@ coveralls: gotest
 	goveralls -coverprofile=ci/out/coverage.prof
 
 gotest:
-	go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./...
+	go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./...
 	sed -i '/stringer\.go/d' ci/out/coverage.prof
 	sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
+	sed -i '/chat-example/d' ci/out/coverage.prof
diff --git a/conn_notjs.go b/conn_notjs.go
index bb2eb22f7dbad0e22a5b932fbbab50f7381c64ce..f604898ed45b37145d782833547946ecec19b106 100644
--- a/conn_notjs.go
+++ b/conn_notjs.go
@@ -33,6 +33,7 @@ type Conn struct {
 	flateThreshold int
 	br             *bufio.Reader
 	bw             *bufio.Writer
+	g              *Grace
 
 	readTimeout  chan context.Context
 	writeTimeout chan context.Context
@@ -138,6 +139,10 @@ func (c *Conn) close(err error) {
 	// closeErr.
 	c.rwc.Close()
 
+	if c.g != nil {
+		c.g.delConn(c)
+	}
+
 	go func() {
 		c.msgWriterState.close()
 
diff --git a/conn_test.go b/conn_test.go
index 28da3c0788213a08757e7f80250c8297c0cf0fe1..af4fa4c0baa71c886c72aabd071595b30c62bc1e 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -13,7 +13,6 @@ import (
 	"os"
 	"os/exec"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -272,11 +271,9 @@ func TestWasm(t *testing.T) {
 		t.Skip("skipping on CI")
 	}
 
-	var wg sync.WaitGroup
-	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		wg.Add(1)
-		defer wg.Done()
-
+	var g websocket.Grace
+	defer g.Close()
+	s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 			Subprotocols:       []string{"echo"},
 			InsecureSkipVerify: true,
@@ -294,8 +291,7 @@ func TestWasm(t *testing.T) {
 			t.Errorf("echo server failed: %v", err)
 			return
 		}
-	}))
-	defer wg.Wait()
+	})))
 	defer s.Close()
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
diff --git a/example_echo_test.go b/example_echo_test.go
index cd195d2e1eea95e14109833272728ff171dcf963..0c0b84ea1f94570179eae2a54005f3f10dd46f29 100644
--- a/example_echo_test.go
+++ b/example_echo_test.go
@@ -31,13 +31,15 @@ func Example_echo() {
 	}
 	defer l.Close()
 
+	var g websocket.Grace
+	defer g.Close()
 	s := &http.Server{
-		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			err := echoServer(w, r)
 			if err != nil {
 				log.Printf("echo server: %v", err)
 			}
-		}),
+		})),
 		ReadTimeout:  time.Second * 15,
 		WriteTimeout: time.Second * 15,
 	}
diff --git a/example_test.go b/example_test.go
index c56e53f354355b506c82851a4df32468aa253be5..462de3761044d4741a77bc2bd62ed45896a6e2b9 100644
--- a/example_test.go
+++ b/example_test.go
@@ -6,6 +6,8 @@ import (
 	"context"
 	"log"
 	"net/http"
+	"os"
+	"os/signal"
 	"time"
 
 	"nhooyr.io/websocket"
@@ -133,3 +135,55 @@ func Example_crossOrigin() {
 	err := http.ListenAndServe("localhost:8080", fn)
 	log.Fatal(err)
 }
+
+// This example demonstrates how to create a WebSocket server
+// that gracefully exits when sent a signal.
+//
+// It starts a WebSocket server that keeps every connection open
+// for 10 seconds.
+// If you CTRL+C while a connection is open, it will wait at most 30s
+// for all connections to terminate before shutting down.
+func ExampleGrace() {
+	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		c, err := websocket.Accept(w, r, nil)
+		if err != nil {
+			log.Println(err)
+			return
+		}
+		defer c.Close(websocket.StatusInternalError, "the sky is falling")
+
+		ctx := c.CloseRead(r.Context())
+		select {
+		case <-ctx.Done():
+		case <-time.After(time.Second * 10):
+		}
+
+		c.Close(websocket.StatusNormalClosure, "")
+	})
+
+	var g websocket.Grace
+	s := &http.Server{
+		Handler:      g.Handler(fn),
+		ReadTimeout:  time.Second * 15,
+		WriteTimeout: time.Second * 15,
+	}
+
+	errc := make(chan error, 1)
+	go func() {
+		errc <- s.ListenAndServe()
+	}()
+
+	sigs := make(chan os.Signal, 1)
+	signal.Notify(sigs, os.Interrupt)
+	select {
+	case err := <-errc:
+		log.Printf("failed to listen and serve: %v", err)
+	case sig := <-sigs:
+		log.Printf("terminating: %v", sig)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+	defer cancel()
+	s.Shutdown(ctx)
+	g.Shutdown(ctx)
+}
diff --git a/grace.go b/grace.go
new file mode 100644
index 0000000000000000000000000000000000000000..c53cd40beb6e0ec00231425ff5939d3ab9b9d46a
--- /dev/null
+++ b/grace.go
@@ -0,0 +1,127 @@
+package websocket
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"sync"
+	"time"
+)
+
+// Grace enables graceful shutdown of accepted WebSocket connections.
+//
+// Use Handler to wrap WebSocket handlers to record accepted connections
+// and then use Close or Shutdown to gracefully close these connections.
+//
+// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
+// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket
+// connections.
+type Grace struct {
+	mu           sync.Mutex
+	closed       bool
+	shuttingDown bool
+	conns        map[*Conn]struct{}
+}
+
+// Handler returns a handler that wraps around h to record
+// all WebSocket connections accepted.
+//
+// Use Close or Shutdown to gracefully close recorded connections.
+func (g *Grace) Handler(h http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ctx := context.WithValue(r.Context(), gracefulContextKey{}, g)
+		r = r.WithContext(ctx)
+		h.ServeHTTP(w, r)
+	})
+}
+
+func (g *Grace) isShuttingdown() bool {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	return g.shuttingDown
+}
+
+func graceFromRequest(r *http.Request) *Grace {
+	g, _ := r.Context().Value(gracefulContextKey{}).(*Grace)
+	return g
+}
+
+func (g *Grace) addConn(c *Conn) error {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	if g.closed {
+		c.Close(StatusGoingAway, "server shutting down")
+		return errors.New("server shutting down")
+	}
+	if g.conns == nil {
+		g.conns = make(map[*Conn]struct{})
+	}
+	g.conns[c] = struct{}{}
+	c.g = g
+	return nil
+}
+
+func (g *Grace) delConn(c *Conn) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	delete(g.conns, c)
+}
+
+type gracefulContextKey struct{}
+
+// Close prevents the acceptance of new connections with
+// http.StatusServiceUnavailable and closes all accepted
+// connections with StatusGoingAway.
+func (g *Grace) Close() error {
+	g.mu.Lock()
+	g.shuttingDown = true
+	g.closed = true
+	var wg sync.WaitGroup
+	for c := range g.conns {
+		wg.Add(1)
+		go func(c *Conn) {
+			defer wg.Done()
+			c.Close(StatusGoingAway, "server shutting down")
+		}(c)
+
+		delete(g.conns, c)
+	}
+	g.mu.Unlock()
+
+	wg.Wait()
+
+	return nil
+}
+
+// Shutdown prevents the acceptance of new connections and waits until
+// all connections close. If the context is cancelled before that, it
+// calls Close to close all connections immediately.
+func (g *Grace) Shutdown(ctx context.Context) error {
+	defer g.Close()
+
+	g.mu.Lock()
+	g.shuttingDown = true
+	g.mu.Unlock()
+
+	// Same poll period used by net/http.
+	t := time.NewTicker(500 * time.Millisecond)
+	defer t.Stop()
+	for {
+		if g.zeroConns() {
+			return nil
+		}
+
+		select {
+		case <-t.C:
+		case <-ctx.Done():
+			return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err())
+		}
+	}
+}
+
+func (g *Grace) zeroConns() bool {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	return len(g.conns) == 0
+}
diff --git a/ws_js.go b/ws_js.go
index 2b560ce87d93035a4e81e167a844d301ae4f1af4..a8c8b77187d956b0eb27376fb1926a7973124f20 100644
--- a/ws_js.go
+++ b/ws_js.go
@@ -38,6 +38,8 @@ type Conn struct {
 	readSignal chan struct{}
 	readBufMu  sync.Mutex
 	readBuf    []wsjs.MessageEvent
+
+	g *Grace
 }
 
 func (c *Conn) close(err error, wasClean bool) {