diff --git a/accept.go b/accept.go index 47e20b52c373f6757475c86c30b06626f3dfc745..dd96c9bd96c2fccaebbaf535e4ce53bbe1f8279f 100644 --- a/accept.go +++ b/accept.go @@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") + g := graceFromRequest(r) + if g != nil && g.isShuttingdown() { + err := errors.New("server shutting down") + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return nil, err + } + if opts == nil { opts = &AcceptOptions{} } @@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - return newConn(connConfig{ + c := newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, @@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con br: brw.Reader, bw: brw.Writer, - }), nil + }) + + if g != nil { + err = g.addConn(c) + if err != nil { + return nil, err + } + } + + return c, nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { diff --git a/chat-example/README.md b/chat-example/README.md index ef06275db3bd1bb6ff2c8c85033e144a2b892965..a4c99a93447f0c4eecacdd48b17a0f285c680de6 100644 --- a/chat-example/README.md +++ b/chat-example/README.md @@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`. + +There are two automated tests for the server included in `chat_test.go`. The first is a simple one +client echo test. It publishes a single message and ensures it's received. + +The second is a complex concurrency test where 10 clients send 128 unique messages +of max 128 bytes concurrently. The test ensures all messages are seen by every client. diff --git a/chat-example/chat.go b/chat-example/chat.go index e6e355d04688bd37f7c2495a5483c794d4453a55..532e50f544fbc39f2ca0a870b0c9cc7ef8ae8e44 100644 --- a/chat-example/chat.go +++ b/chat-example/chat.go @@ -3,20 +3,67 @@ package main import ( "context" "errors" - "io" "io/ioutil" "log" "net/http" "sync" "time" + "golang.org/x/time/rate" + "nhooyr.io/websocket" ) // chatServer enables broadcasting to a set of subscribers. type chatServer struct { - subscribersMu sync.RWMutex - subscribers map[chan<- []byte]struct{} + // subscriberMessageBuffer controls the max number + // of messages that can be queued for a subscriber + // before it is kicked. + // + // Defaults to 16. + subscriberMessageBuffer int + + // publishLimiter controls the rate limit applied to the publish endpoint. + // + // Defaults to one publish every 100ms with a burst of 8. + publishLimiter *rate.Limiter + + // logf controls where logs are sent. + // Defaults to log.Printf. + logf func(f string, v ...interface{}) + + // serveMux routes the various endpoints to the appropriate handler. + serveMux http.ServeMux + + subscribersMu sync.Mutex + subscribers map[*subscriber]struct{} +} + +// newChatServer constructs a chatServer with the defaults. +func newChatServer() *chatServer { + cs := &chatServer{ + subscriberMessageBuffer: 16, + logf: log.Printf, + subscribers: make(map[*subscriber]struct{}), + publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), + } + cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) + cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) + cs.serveMux.HandleFunc("/publish", cs.publishHandler) + + return cs +} + +// subscriber represents a subscriber. +// Messages are sent on the msgs channel and if the client +// cannot keep up with the messages, closeSlow is called. +type subscriber struct { + msgs chan []byte + closeSlow func() +} + +func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes @@ -24,7 +71,7 @@ type chatServer struct { func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { - log.Print(err) + cs.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "") @@ -38,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { return } if err != nil { - log.Print(err) + cs.logf("%v", err) + return } } @@ -49,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } - body := io.LimitReader(r.Body, 8192) + body := http.MaxBytesReader(w, r.Body, 8192) msg, err := ioutil.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) @@ -57,11 +105,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { } cs.publish(msg) + + w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. -// It creates a msgs chan with a buffer of 16 to give some room to slower -// connections and then registers it. It then listens for all messages +// It creates a subscriber with a buffered msgs chan to give some room to slower +// connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // @@ -70,13 +120,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) - msgs := make(chan []byte, 16) - cs.addSubscriber(msgs) - defer cs.deleteSubscriber(msgs) + s := &subscriber{ + msgs: make(chan []byte, cs.subscriberMessageBuffer), + closeSlow: func() { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + }, + } + cs.addSubscriber(s) + defer cs.deleteSubscriber(s) for { select { - case msg := <-msgs: + case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err @@ -91,32 +146,31 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { - cs.subscribersMu.RLock() - defer cs.subscribersMu.RUnlock() + cs.subscribersMu.Lock() + defer cs.subscribersMu.Unlock() + + cs.publishLimiter.Wait(context.Background()) - for c := range cs.subscribers { + for s := range cs.subscribers { select { - case c <- msg: + case s.msgs <- msg: default: + go s.closeSlow() } } } -// addSubscriber registers a subscriber with a channel -// on which to send messages. -func (cs *chatServer) addSubscriber(msgs chan<- []byte) { +// addSubscriber registers a subscriber. +func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() - if cs.subscribers == nil { - cs.subscribers = make(map[chan<- []byte]struct{}) - } - cs.subscribers[msgs] = struct{}{} + cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } -// deleteSubscriber deletes the subscriber with the given msgs channel. -func (cs *chatServer) deleteSubscriber(msgs chan []byte) { +// deleteSubscriber deletes the given subscriber. +func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() - delete(cs.subscribers, msgs) + delete(cs.subscribers, s) cs.subscribersMu.Unlock() } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go new file mode 100644 index 0000000000000000000000000000000000000000..491499ccee6d8669f695da478736d05ca04fa2af --- /dev/null +++ b/chat-example/chat_test.go @@ -0,0 +1,282 @@ +// +build !js + +package main + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/time/rate" + + "nhooyr.io/websocket" +) + +func Test_chatServer(t *testing.T) { + t.Parallel() + + // This is a simple echo test with a single client. + // The client sends a message and ensures it receives + // it on its WebSocket. + t.Run("simple", func(t *testing.T) { + t.Parallel() + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + expMsg := randString(512) + err = cl.publish(ctx, expMsg) + assertSuccess(t, err) + + msg, err := cl.nextMessage() + assertSuccess(t, err) + + if expMsg != msg { + t.Fatalf("expected %v but got %v", expMsg, msg) + } + }) + + // This test is a complex concurrency test. + // 10 clients are started that send 128 different + // messages of max 128 bytes concurrently. + // + // The test verifies that every message is seen by ever client + // and no errors occur anywhere. + t.Run("concurrency", func(t *testing.T) { + t.Parallel() + + const nmessages = 128 + const maxMessageSize = 128 + const nclients = 10 + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var clients []*client + var clientMsgs []map[string]struct{} + for i := 0; i < nclients; i++ { + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + clients = append(clients, cl) + clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) + } + + allMessages := make(map[string]struct{}) + for _, msgs := range clientMsgs { + for m := range msgs { + allMessages[m] = struct{}{} + } + } + + var wg sync.WaitGroup + for i, cl := range clients { + i := i + cl := cl + + wg.Add(1) + go func() { + defer wg.Done() + err := cl.publishMsgs(ctx, clientMsgs[i]) + if err != nil { + t.Errorf("client %d failed to publish all messages: %v", i, err) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) + if err != nil { + t.Errorf("client %d failed to receive all messages: %v", i, err) + } + }() + } + + wg.Wait() + }) +} + +// setupTest sets up chatServer that can be used +// via the returned url. +// +// Defer closeFn to ensure everything is cleaned up at +// the end of the test. +// +// chatServer logs will be logged via t.Logf. +func setupTest(t *testing.T) (url string, closeFn func()) { + cs := newChatServer() + cs.logf = t.Logf + + // To ensure tests run quickly under even -race. + cs.subscriberMessageBuffer = 4096 + cs.publishLimiter.SetLimit(rate.Inf) + + var g websocket.Grace + s := httptest.NewServer(g.Handler(cs)) + return s.URL, func() { + s.Close() + g.Close() + } +} + +// testAllMessagesReceived ensures that after n reads, all msgs in msgs +// have been read. +func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { + msgs = cloneMessages(msgs) + + for i := 0; i < n; i++ { + msg, err := cl.nextMessage() + if err != nil { + return err + } + delete(msgs, msg) + } + + if len(msgs) != 0 { + return fmt.Errorf("did not receive all expected messages: %q", msgs) + } + return nil +} + +func cloneMessages(msgs map[string]struct{}) map[string]struct{} { + msgs2 := make(map[string]struct{}, len(msgs)) + for m := range msgs { + msgs2[m] = struct{}{} + } + return msgs2 +} + +func randMessages(n, maxMessageLength int) map[string]struct{} { + msgs := make(map[string]struct{}) + for i := 0; i < n; i++ { + m := randString(randInt(maxMessageLength)) + if _, ok := msgs[m]; ok { + i-- + continue + } + msgs[m] = struct{}{} + } + return msgs +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +type client struct { + url string + c *websocket.Conn +} + +func newClient(ctx context.Context, url string) (*client, error) { + wsURL := strings.Replace(url, "http://", "ws://", 1) + c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) + if err != nil { + return nil, err + } + + cl := &client{ + url: url, + c: c, + } + + return cl, nil +} + +func (cl *client) publish(ctx context.Context, msg string) (err error) { + defer func() { + if err != nil { + cl.c.Close(websocket.StatusInternalError, "publish failed") + } + }() + + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("publish request failed: %v", resp.StatusCode) + } + return nil +} + +func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { + for m := range msgs { + err := cl.publish(ctx, m) + if err != nil { + return err + } + } + return nil +} + +func (cl *client) nextMessage() (string, error) { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return "", err + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return "", fmt.Errorf("expected text message but got %v", typ) + } + return string(b), nil +} + +func (cl *client) Close() error { + return cl.c.Close(websocket.StatusNormalClosure, "") +} + +// randString generates a random string with length n. +func randString(n int) string { + b := make([]byte, n) + _, err := rand.Reader.Read(b) + if err != nil { + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + + s := strings.ToValidUTF8(string(b), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] + } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s +} + +// randInt returns a randomly generated integer between [0, max). +func randInt(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + panic(fmt.Sprintf("failed to get random int: %v", err)) + } + return int(x.Int64()) +} diff --git a/chat-example/go.mod b/chat-example/go.mod deleted file mode 100644 index 34fa5a69cef29a336862718fdc82ad48057f8704..0000000000000000000000000000000000000000 --- a/chat-example/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module nhooyr.io/websocket/example-chat - -go 1.13 - -require nhooyr.io/websocket v1.8.2 diff --git a/chat-example/go.sum b/chat-example/go.sum index 0755fca5eb20b3bfae3ab488ab2a6a723352f490..e4bbd62d337c4edbd4e71bd740271cf2544a0467 100644 --- a/chat-example/go.sum +++ b/chat-example/go.sum @@ -1,12 +1,18 @@ +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA= -nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ= diff --git a/chat-example/index.css b/chat-example/index.css index 2980466285849e233c95bfa0a2c9c2339a185e1c..73a8e0f3af030e225106234ef634a3aea1e5be3d 100644 --- a/chat-example/index.css +++ b/chat-example/index.css @@ -5,7 +5,7 @@ body { #root { padding: 40px 20px; - max-width: 480px; + max-width: 600px; margin: auto; height: 100vh; diff --git a/chat-example/index.js b/chat-example/index.js index 8fb3dfb8a6521c9fa4e64cd663556e17ea250ce1..5868e7caeeeebbdfa6cbe5d0a4a8c82f9268ab98 100644 --- a/chat-example/index.js +++ b/chat-example/index.js @@ -7,8 +7,11 @@ const conn = new WebSocket(`ws://${location.host}/subscribe`) conn.addEventListener("close", ev => { - console.info("websocket disconnected, reconnecting in 1000ms", ev) - setTimeout(dial, 1000) + appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) + if (ev.code !== 1001) { + appendLog("Reconnecting in 1s", true) + setTimeout(dial, 1000) + } }) conn.addEventListener("open", ev => { console.info("websocket connected") @@ -34,17 +37,21 @@ const messageInput = document.getElementById("message-input") // appendLog appends the passed text to messageLog. - function appendLog(text) { + function appendLog(text, error) { const p = document.createElement("p") // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` + if (error) { + p.style.color = "red" + p.style.fontStyle = "bold" + } messageLog.append(p) return p } appendLog("Submit a message to get started!") // onsubmit publishes the message from the user when the form is submitted. - publishForm.onsubmit = ev => { + publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value @@ -54,9 +61,16 @@ messageInput.value = "" expectingMessage = true - fetch("/publish", { - method: "POST", - body: msg, - }) + try { + const resp = await fetch("/publish", { + method: "POST", + body: msg, + }) + if (resp.status !== 202) { + throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) + } + } catch (err) { + appendLog(`Publish failed: ${err.message}`, true) + } } })() diff --git a/chat-example/main.go b/chat-example/main.go index 2a5209244445ee7443866418115635fdede8bc48..1b6f3266cab9d37d320f74a62572830c29bbd47b 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -1,12 +1,16 @@ package main import ( + "context" "errors" "log" "net" "net/http" "os" + "os/signal" "time" + + "nhooyr.io/websocket" ) func main() { @@ -31,17 +35,32 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var ws chatServer - - m := http.NewServeMux() - m.Handle("/", http.FileServer(http.Dir("."))) - m.HandleFunc("/subscribe", ws.subscribeHandler) - m.HandleFunc("/publish", ws.publishHandler) - + cs := newChatServer() + var g websocket.Grace s := http.Server{ - Handler: m, + Handler: g.Handler(cs), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } - return s.Serve(l) + errc := make(chan error, 1) + go func() { + errc <- s.Serve(l) + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case err := <-errc: + log.Printf("failed to serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + s.Shutdown(ctx) + g.Shutdown(ctx) + + return nil } diff --git a/ci/test.mk b/ci/test.mk index c62a25b64834a1f913fcddff87ac48441edba133..291d6beb11502e5447c49374876fc72a3023750b 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -11,6 +11,7 @@ coveralls: gotest goveralls -coverprofile=ci/out/coverage.prof gotest: - go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... + go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof + sed -i '/chat-example/d' ci/out/coverage.prof diff --git a/conn_notjs.go b/conn_notjs.go index bb2eb22f7dbad0e22a5b932fbbab50f7381c64ce..f604898ed45b37145d782833547946ecec19b106 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -33,6 +33,7 @@ type Conn struct { flateThreshold int br *bufio.Reader bw *bufio.Writer + g *Grace readTimeout chan context.Context writeTimeout chan context.Context @@ -138,6 +139,10 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() + if c.g != nil { + c.g.delConn(c) + } + go func() { c.msgWriterState.close() diff --git a/conn_test.go b/conn_test.go index 28da3c0788213a08757e7f80250c8297c0cf0fe1..af4fa4c0baa71c886c72aabd071595b30c62bc1e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,7 +13,6 @@ import ( "os" "os/exec" "strings" - "sync" "testing" "time" @@ -272,11 +271,9 @@ func TestWasm(t *testing.T) { t.Skip("skipping on CI") } - var wg sync.WaitGroup - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wg.Add(1) - defer wg.Done() - + var g websocket.Grace + defer g.Close() + s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, @@ -294,8 +291,7 @@ func TestWasm(t *testing.T) { t.Errorf("echo server failed: %v", err) return } - })) - defer wg.Wait() + }))) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) diff --git a/example_echo_test.go b/example_echo_test.go index cd195d2e1eea95e14109833272728ff171dcf963..0c0b84ea1f94570179eae2a54005f3f10dd46f29 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -31,13 +31,15 @@ func Example_echo() { } defer l.Close() + var g websocket.Grace + defer g.Close() s := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r) if err != nil { log.Printf("echo server: %v", err) } - }), + })), ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } diff --git a/example_test.go b/example_test.go index c56e53f354355b506c82851a4df32468aa253be5..462de3761044d4741a77bc2bd62ed45896a6e2b9 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,8 @@ import ( "context" "log" "net/http" + "os" + "os/signal" "time" "nhooyr.io/websocket" @@ -133,3 +135,55 @@ func Example_crossOrigin() { err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } + +// This example demonstrates how to create a WebSocket server +// that gracefully exits when sent a signal. +// +// It starts a WebSocket server that keeps every connection open +// for 10 seconds. +// If you CTRL+C while a connection is open, it will wait at most 30s +// for all connections to terminate before shutting down. +func ExampleGrace() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx := c.CloseRead(r.Context()) + select { + case <-ctx.Done(): + case <-time.After(time.Second * 10): + } + + c.Close(websocket.StatusNormalClosure, "") + }) + + var g websocket.Grace + s := &http.Server{ + Handler: g.Handler(fn), + ReadTimeout: time.Second * 15, + WriteTimeout: time.Second * 15, + } + + errc := make(chan error, 1) + go func() { + errc <- s.ListenAndServe() + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case err := <-errc: + log.Printf("failed to listen and serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + s.Shutdown(ctx) + g.Shutdown(ctx) +} diff --git a/grace.go b/grace.go new file mode 100644 index 0000000000000000000000000000000000000000..c53cd40beb6e0ec00231425ff5939d3ab9b9d46a --- /dev/null +++ b/grace.go @@ -0,0 +1,127 @@ +package websocket + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "time" +) + +// Grace enables graceful shutdown of accepted WebSocket connections. +// +// Use Handler to wrap WebSocket handlers to record accepted connections +// and then use Close or Shutdown to gracefully close these connections. +// +// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. +// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket +// connections. +type Grace struct { + mu sync.Mutex + closed bool + shuttingDown bool + conns map[*Conn]struct{} +} + +// Handler returns a handler that wraps around h to record +// all WebSocket connections accepted. +// +// Use Close or Shutdown to gracefully close recorded connections. +func (g *Grace) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), gracefulContextKey{}, g) + r = r.WithContext(ctx) + h.ServeHTTP(w, r) + }) +} + +func (g *Grace) isShuttingdown() bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.shuttingDown +} + +func graceFromRequest(r *http.Request) *Grace { + g, _ := r.Context().Value(gracefulContextKey{}).(*Grace) + return g +} + +func (g *Grace) addConn(c *Conn) error { + g.mu.Lock() + defer g.mu.Unlock() + if g.closed { + c.Close(StatusGoingAway, "server shutting down") + return errors.New("server shutting down") + } + if g.conns == nil { + g.conns = make(map[*Conn]struct{}) + } + g.conns[c] = struct{}{} + c.g = g + return nil +} + +func (g *Grace) delConn(c *Conn) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.conns, c) +} + +type gracefulContextKey struct{} + +// Close prevents the acceptance of new connections with +// http.StatusServiceUnavailable and closes all accepted +// connections with StatusGoingAway. +func (g *Grace) Close() error { + g.mu.Lock() + g.shuttingDown = true + g.closed = true + var wg sync.WaitGroup + for c := range g.conns { + wg.Add(1) + go func(c *Conn) { + defer wg.Done() + c.Close(StatusGoingAway, "server shutting down") + }(c) + + delete(g.conns, c) + } + g.mu.Unlock() + + wg.Wait() + + return nil +} + +// Shutdown prevents the acceptance of new connections and waits until +// all connections close. If the context is cancelled before that, it +// calls Close to close all connections immediately. +func (g *Grace) Shutdown(ctx context.Context) error { + defer g.Close() + + g.mu.Lock() + g.shuttingDown = true + g.mu.Unlock() + + // Same poll period used by net/http. + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + if g.zeroConns() { + return nil + } + + select { + case <-t.C: + case <-ctx.Done(): + return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err()) + } + } +} + +func (g *Grace) zeroConns() bool { + g.mu.Lock() + defer g.mu.Unlock() + return len(g.conns) == 0 +} diff --git a/ws_js.go b/ws_js.go index 2b560ce87d93035a4e81e167a844d301ae4f1af4..a8c8b77187d956b0eb27376fb1926a7973124f20 100644 --- a/ws_js.go +++ b/ws_js.go @@ -38,6 +38,8 @@ type Conn struct { readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent + + g *Grace } func (c *Conn) close(err error, wasClean bool) {