good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit af0fd9d4 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

examples/chat: Fix race condition

Tricky tricky.
parent ff3ea39b
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"sync" "sync"
"time" "time"
...@@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { ...@@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// subscribeHandler accepts the WebSocket connection and then subscribes // subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages. // it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil) err := cs.subscribe(r.Context(), w, r)
if err != nil {
cs.logf("%v", err)
return
}
defer c.CloseNow()
err = cs.subscribe(r.Context(), c)
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return return
} }
...@@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { ...@@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
// //
// It uses CloseRead to keep reading from the connection to process control // It uses CloseRead to keep reading from the connection to process control
// messages and cancel the context if the connection drops. // messages and cancel the context if the connection drops.
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
ctx = c.CloseRead(ctx) var mu sync.Mutex
var c *websocket.Conn
var closed bool
s := &subscriber{ s := &subscriber{
msgs: make(chan []byte, cs.subscriberMessageBuffer), msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() { closeSlow: func() {
mu.Lock()
defer mu.Unlock()
closed = true
if c != nil {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
}
}, },
} }
cs.addSubscriber(s) cs.addSubscriber(s)
defer cs.deleteSubscriber(s) defer cs.deleteSubscriber(s)
c2, err := websocket.Accept(w, r, nil)
if err != nil {
return err
}
mu.Lock()
if closed {
mu.Unlock()
return net.ErrClosed
}
c = c2
mu.Unlock()
defer c.CloseNow()
ctx = c.CloseRead(ctx)
for { for {
select { select {
case msg := <-s.msgs: case msg := <-s.msgs:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment