From 97172f3339a9bef16fa82fde84b4b0c7a1357e56 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Tue, 25 Feb 2020 23:59:57 -0500 Subject: [PATCH] Add Grace to gracefully close WebSocket connections Closes #199 --- accept.go | 20 ++++++- conn_notjs.go | 5 ++ conn_test.go | 12 ++--- example_echo_test.go | 6 ++- example_test.go | 46 ++++++++++++++++ grace.go | 123 +++++++++++++++++++++++++++++++++++++++++++ ws_js.go | 2 + 7 files changed, 202 insertions(+), 12 deletions(-) create mode 100644 grace.go diff --git a/accept.go b/accept.go index 47e20b5..52a9345 100644 --- a/accept.go +++ b/accept.go @@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") + g := graceFromRequest(r) + if g != nil && g.isClosing() { + err := errors.New("server closing") + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return nil, err + } + if opts == nil { opts = &AcceptOptions{} } @@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - return newConn(connConfig{ + c := newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, @@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con br: brw.Reader, bw: brw.Writer, - }), nil + }) + + if g != nil { + err = g.addConn(c) + if err != nil { + return nil, err + } + } + + return c, nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { diff --git a/conn_notjs.go b/conn_notjs.go index bb2eb22..f604898 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -33,6 +33,7 @@ type Conn struct { flateThreshold int br *bufio.Reader bw *bufio.Writer + g *Grace readTimeout chan context.Context writeTimeout chan context.Context @@ -138,6 +139,10 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() + if c.g != nil { + c.g.delConn(c) + } + go func() { c.msgWriterState.close() diff --git a/conn_test.go b/conn_test.go index 28da3c0..af4fa4c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,7 +13,6 @@ import ( "os" "os/exec" "strings" - "sync" "testing" "time" @@ -272,11 +271,9 @@ func TestWasm(t *testing.T) { t.Skip("skipping on CI") } - var wg sync.WaitGroup - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wg.Add(1) - defer wg.Done() - + var g websocket.Grace + defer g.Close() + s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, @@ -294,8 +291,7 @@ func TestWasm(t *testing.T) { t.Errorf("echo server failed: %v", err) return } - })) - defer wg.Wait() + }))) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) diff --git a/example_echo_test.go b/example_echo_test.go index cd195d2..0c0b84e 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -31,13 +31,15 @@ func Example_echo() { } defer l.Close() + var g websocket.Grace + defer g.Close() s := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r) if err != nil { log.Printf("echo server: %v", err) } - }), + })), ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } diff --git a/example_test.go b/example_test.go index c56e53f..ce049bc 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,8 @@ import ( "context" "log" "net/http" + "os" + "os/signal" "time" "nhooyr.io/websocket" @@ -133,3 +135,47 @@ func Example_crossOrigin() { err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } + +// This example demonstrates how to create a WebSocket server +// that gracefully exits when sent a signal. +// +// It starts a WebSocket server that keeps every connection open +// for 10 seconds. +// If you CTRL+C while a connection is open, it will wait at most 30s +// for all connections to terminate before shutting down. +func ExampleGrace() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx := c.CloseRead(r.Context()) + select { + case <-ctx.Done(): + case <-time.After(time.Second * 10): + } + + c.Close(websocket.StatusNormalClosure, "") + }) + + var g websocket.Grace + s := &http.Server{ + Handler: g.Handler(fn), + ReadTimeout: time.Second * 15, + WriteTimeout: time.Second * 15, + } + go s.ListenAndServe() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + sig := <-sigs + log.Printf("recieved %v, shutting down", sig) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + s.Shutdown(ctx) + g.Shutdown(ctx) +} diff --git a/grace.go b/grace.go new file mode 100644 index 0000000..8dadc43 --- /dev/null +++ b/grace.go @@ -0,0 +1,123 @@ +package websocket + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "time" +) + +// Grace enables graceful shutdown of accepted WebSocket connections. +// +// Use Handler to wrap WebSocket handlers to record accepted connections +// and then use Close or Shutdown to gracefully close these connections. +// +// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. +type Grace struct { + mu sync.Mutex + closing bool + conns map[*Conn]struct{} +} + +// Handler returns a handler that wraps around h to record +// all WebSocket connections accepted. +// +// Use Close or Shutdown to gracefully close recorded connections. +func (g *Grace) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), gracefulContextKey{}, g) + r = r.WithContext(ctx) + h.ServeHTTP(w, r) + }) +} + +func (g *Grace) isClosing() bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.closing +} + +func graceFromRequest(r *http.Request) *Grace { + g, _ := r.Context().Value(gracefulContextKey{}).(*Grace) + return g +} + +func (g *Grace) addConn(c *Conn) error { + g.mu.Lock() + defer g.mu.Unlock() + if g.closing { + c.Close(StatusGoingAway, "server shutting down") + return errors.New("server shutting down") + } + if g.conns == nil { + g.conns = make(map[*Conn]struct{}) + } + g.conns[c] = struct{}{} + c.g = g + return nil +} + +func (g *Grace) delConn(c *Conn) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.conns, c) +} + +type gracefulContextKey struct{} + +// Close prevents the acceptance of new connections with +// http.StatusServiceUnavailable and closes all accepted +// connections with StatusGoingAway. +func (g *Grace) Close() error { + g.mu.Lock() + g.closing = true + var wg sync.WaitGroup + for c := range g.conns { + wg.Add(1) + go func(c *Conn) { + defer wg.Done() + c.Close(StatusGoingAway, "server shutting down") + }(c) + + delete(g.conns, c) + } + g.mu.Unlock() + + wg.Wait() + + return nil +} + +// Shutdown prevents the acceptance of new connections and waits until +// all connections close. If the context is cancelled before that, it +// calls Close to close all connections immediately. +func (g *Grace) Shutdown(ctx context.Context) error { + defer g.Close() + + g.mu.Lock() + g.closing = true + g.mu.Unlock() + + // Same poll period used by net/http. + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + if g.zeroConns() { + return nil + } + + select { + case <-t.C: + case <-ctx.Done(): + return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err()) + } + } +} + +func (g *Grace) zeroConns() bool { + g.mu.Lock() + defer g.mu.Unlock() + return len(g.conns) == 0 +} diff --git a/ws_js.go b/ws_js.go index 2b560ce..a8c8b77 100644 --- a/ws_js.go +++ b/ws_js.go @@ -38,6 +38,8 @@ type Conn struct { readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent + + g *Grace } func (c *Conn) close(err error, wasClean bool) { -- GitLab