diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b07c54b8854de084a863df77470795bb535c453a..a53a4697696211f742045a7914dcd775bd04f4ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,24 +7,18 @@ jobs: container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/fmt.sh lint: runs-on: ubuntu-latest container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/lint.sh test: runs-on: ubuntu-latest container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/test.sh env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} @@ -33,6 +27,4 @@ jobs: container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/wasm.sh diff --git a/assert_test.go b/assert_test.go index 8970c5437588bd835fc4cbf89b79d50790d28879..e67ed539f5717ce579649e0dd2334a80908d3d1a 100644 --- a/assert_test.go +++ b/assert_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "reflect" "strings" + "time" "github.com/google/go-cmp/cmp" @@ -13,6 +14,10 @@ import ( "nhooyr.io/websocket/wsjson" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + // https://github.com/google/go-cmp/issues/40#issuecomment-328615283 func cmpDiff(exp, act interface{}) string { return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) diff --git a/ci/wasm.sh b/ci/wasm.sh index 134b60b51e9247965d071187bb51d61396447716..c1d9a4045cda479e9c558a0ed0581d33de143141 100755 --- a/ci/wasm.sh +++ b/ci/wasm.sh @@ -25,7 +25,7 @@ go install github.com/agnivade/wasmbrowsertest export WS_ECHO_SERVER_URL GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./... -kill "$wsjstestPID" +kill "$wsjstestPID" || true if ! wait "$wsjstestPID"; then echo "--- wsjstest exited unsuccessfully" echo "output:" diff --git a/conn.go b/conn.go index 37c4cac21c2bc894fc2c8fa353656c38006f8b8c..d74b87538ff5a5cfdf476e4c050a0c56a841df8e 100644 --- a/conn.go +++ b/conn.go @@ -5,13 +5,12 @@ package websocket import ( "bufio" "context" - cryptorand "crypto/rand" + "crypto/rand" "errors" "fmt" "io" "io/ioutil" "log" - "math/rand" "runtime" "strconv" "sync" @@ -82,6 +81,7 @@ type Conn struct { setReadTimeout chan context.Context setWriteTimeout chan context.Context + pingCounter *atomicInt64 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } @@ -100,6 +100,7 @@ func (c *Conn) init() { c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) + c.pingCounter = &atomicInt64{} c.activePings = make(map[string]chan<- struct{}) c.writeHeaderBuf = makeWriteHeaderBuf() @@ -669,7 +670,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte c.writeHeader.payloadLength = int64(len(p)) if c.client { - _, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:]) + _, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:]) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } @@ -839,10 +840,6 @@ func (c *Conn) writeClose(p []byte, cerr error) error { return nil } -func init() { - rand.Seed(time.Now().UnixNano()) -} - // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does @@ -851,10 +848,9 @@ func init() { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - id := rand.Uint64() - p := strconv.FormatUint(id, 10) + p := c.pingCounter.Increment(1) - err := c.ping(ctx, p) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) } diff --git a/conn_common.go b/conn_common.go index e7a01035e11d45471d43a5fb42d2f7c83f6771d4..8233e4a68d0a241538a417868ddf2f19c8139cb5 100644 --- a/conn_common.go +++ b/conn_common.go @@ -211,21 +211,22 @@ func (c *Conn) setCloseErr(err error) { // See https://github.com/nhooyr/websocket/issues/153 type atomicInt64 struct { - v atomic.Value + v int64 } func (v *atomicInt64) Load() int64 { - i, ok := v.v.Load().(int64) - if !ok { - return 0 - } - return i + return atomic.LoadInt64(&v.v) } func (v *atomicInt64) Store(i int64) { - v.v.Store(i) + atomic.StoreInt64(&v.v, i) } func (v *atomicInt64) String() string { - return fmt.Sprint(v.v.Load()) + return fmt.Sprint(v.Load()) +} + +// Increment increments the value and returns the new value. +func (v *atomicInt64) Increment(delta int64) int64 { + return atomic.AddInt64(&v.v, delta) } diff --git a/conn_test.go b/conn_test.go index 8846979d5dc639836533f315f2ec90d005b412e8..c948c435d3ab78e7cd07446e46296aaf6f724484 100644 --- a/conn_test.go +++ b/conn_test.go @@ -37,6 +37,10 @@ import ( "nhooyr.io/websocket/wspb" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + func TestHandshake(t *testing.T) { t.Parallel() @@ -911,10 +915,6 @@ func TestConn(t *testing.T) { } } -func init() { - rand.Seed(time.Now().UnixNano()) -} - func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { var conns int64 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/frame_test.go b/frame_test.go index 1a2054c12226a7729717e578d67a5848844ce0fe..7d2a571958633a4ae9e2b2fb9bfd4149bcc66b64 100644 --- a/frame_test.go +++ b/frame_test.go @@ -10,10 +10,15 @@ import ( "strconv" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + func randBool() bool { return rand.Intn(1) == 0 } diff --git a/handshake.go b/handshake.go index 0b078085e8bb2b551b6a92446d4a8d78728dc341..d1a9fba4662f22fa8d0ba4361e02139494669373 100644 --- a/handshake.go +++ b/handshake.go @@ -6,13 +6,13 @@ import ( "bufio" "bytes" "context" + "crypto/rand" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "io/ioutil" - "math/rand" "net/http" "net/textproto" "net/url" @@ -299,7 +299,11 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey()) + secWebSocketKey, err := makeSecWebSocketKey() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } @@ -403,8 +407,11 @@ func returnBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } -func makeSecWebSocketKey() string { +func makeSecWebSocketKey() (string, error) { b := make([]byte, 16) - rand.Read(b) - return base64.StdEncoding.EncodeToString(b) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil } diff --git a/handshake_test.go b/handshake_test.go index a3d98163f3d9bab3d88321a69f814d8f0768cb11..cb09353f65d3928630aad627684b046f0b5ddd92 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -367,14 +367,17 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key := makeSecWebSocketKey() + key, err := makeSecWebSocketKey() + if err != nil { + t.Fatal(err) + } r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - err := verifyServerResponse(r, resp) + err = verifyServerResponse(r, resp) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) }