diff --git a/chat-example/chat.go b/chat-example/chat.go index e6e355d04688bd37f7c2495a5483c794d4453a55..9b264195c6d2871643521cfb7a7be749d8113fc2 100644 --- a/chat-example/chat.go +++ b/chat-example/chat.go @@ -15,8 +15,28 @@ import ( // chatServer enables broadcasting to a set of subscribers. type chatServer struct { + registerOnce sync.Once + m http.ServeMux + subscribersMu sync.RWMutex - subscribers map[chan<- []byte]struct{} + subscribers map[*subscriber]struct{} +} + +// subscriber represents a subscriber. +// Messages are sent on the msgs channel and if the client +// cannot keep up with the messages, closeSlow is called. +type subscriber struct { + msgs chan []byte + closeSlow func() +} + +func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cs.registerOnce.Do(func() { + cs.m.Handle("/", http.FileServer(http.Dir("."))) + cs.m.HandleFunc("/subscribe", cs.subscribeHandler) + cs.m.HandleFunc("/publish", cs.publishHandler) + }) + cs.m.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes @@ -57,11 +77,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { } cs.publish(msg) + + w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. -// It creates a msgs chan with a buffer of 16 to give some room to slower -// connections and then registers it. It then listens for all messages +// It creates a subscriber with a buffered msgs chan to give some room to slower +// connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // @@ -70,13 +92,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) - msgs := make(chan []byte, 16) - cs.addSubscriber(msgs) - defer cs.deleteSubscriber(msgs) + s := &subscriber{ + msgs: make(chan []byte, 16), + closeSlow: func() { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + }, + } + cs.addSubscriber(s) + defer cs.deleteSubscriber(s) for { select { - case msg := <-msgs: + case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err @@ -94,29 +121,29 @@ func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.RLock() defer cs.subscribersMu.RUnlock() - for c := range cs.subscribers { + for s := range cs.subscribers { select { - case c <- msg: + case s.msgs <- msg: default: + go s.closeSlow() } } } -// addSubscriber registers a subscriber with a channel -// on which to send messages. -func (cs *chatServer) addSubscriber(msgs chan<- []byte) { +// addSubscriber registers a subscriber. +func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() if cs.subscribers == nil { - cs.subscribers = make(map[chan<- []byte]struct{}) + cs.subscribers = make(map[*subscriber]struct{}) } - cs.subscribers[msgs] = struct{}{} + cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } -// deleteSubscriber deletes the subscriber with the given msgs channel. -func (cs *chatServer) deleteSubscriber(msgs chan []byte) { +// deleteSubscriber deletes the given subscriber. +func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() - delete(cs.subscribers, msgs) + delete(cs.subscribers, s) cs.subscribersMu.Unlock() } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d17723816ad4621a717794849c16ef00fe6f9bf9 --- /dev/null +++ b/chat-example/chat_test.go @@ -0,0 +1,137 @@ +// +build !js + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestGrace(t *testing.T) { + t.Parallel() + + var cs chatServer + var g websocket.Grace + s := httptest.NewServer(g.Handler(&cs)) + defer s.Close() + defer g.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + cl1, err := newClient(ctx, s.URL) + assertSuccess(t, err) + defer cl1.Close() + + cl2, err := newClient(ctx, s.URL) + assertSuccess(t, err) + defer cl2.Close() + + err = cl1.publish(ctx, "hello") + assertSuccess(t, err) + + assertReceivedMessage(ctx, cl1, "hello") + assertReceivedMessage(ctx, cl2, "hello") +} + +type client struct { + msgs chan string + url string + c *websocket.Conn +} + +func newClient(ctx context.Context, url string) (*client, error) { + wsURL := strings.ReplaceAll(url, "http://", "ws://") + c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) + if err != nil { + return nil, err + } + + cl := &client{ + msgs: make(chan string, 16), + url: url, + c: c, + } + go cl.readLoop() + + return cl, nil +} + +func (cl *client) readLoop() { + defer cl.c.Close(websocket.StatusInternalError, "") + defer close(cl.msgs) + + for { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return + } + + select { + case cl.msgs <- string(b): + default: + cl.c.Close(websocket.StatusInternalError, "messages coming in too fast to handle") + return + } + } +} + +func (cl *client) receive(ctx context.Context) (string, error) { + select { + case msg, ok := <-cl.msgs: + if !ok { + return "", errors.New("client closed") + } + return msg, nil + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func (cl *client) publish(ctx context.Context, msg string) error { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("publish request failed: %v", resp.StatusCode) + } + return nil +} + +func (cl *client) Close() error { + return cl.c.Close(websocket.StatusNormalClosure, "") +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func assertReceivedMessage(ctx context.Context, cl *client, msg string) error { + msg, err := cl.receive(ctx) + if err != nil { + return err + } + if msg != "hello" { + return fmt.Errorf("expected hello but got %q", msg) + } + return nil +} diff --git a/chat-example/go.mod b/chat-example/go.mod deleted file mode 100644 index c47a5a2ff66f448ab934a6a4f9ff0f9d346ce32c..0000000000000000000000000000000000000000 --- a/chat-example/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module nhooyr.io/websocket/example-chat - -go 1.13 - -require nhooyr.io/websocket v0.0.0 - -replace nhooyr.io/websocket => ../ diff --git a/chat-example/main.go b/chat-example/main.go index f985d3828f0fb7f3c961f6ad11dbca616e58e539..a265f60ce7c07243589050bc46141e32f1f73ab9 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -35,16 +35,10 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var ws chatServer - - m := http.NewServeMux() - m.Handle("/", http.FileServer(http.Dir("."))) - m.HandleFunc("/subscribe", ws.subscribeHandler) - m.HandleFunc("/publish", ws.publishHandler) - + var cs chatServer var g websocket.Grace s := http.Server{ - Handler: g.Handler(m), + Handler: g.Handler(&cs), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, }