From e476358de061353e5069f490ac09dd3815513e01 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sun, 29 Sep 2019 18:21:31 -0500
Subject: [PATCH] Improve usage of math/rand versus crypto/rand

math/rand was being used inappropiately and did not have
a init function for every file it was used in.
---
 .github/workflows/ci.yml |  8 --------
 assert_test.go           |  5 +++++
 ci/wasm.sh               |  2 +-
 conn.go                  | 16 ++++++----------
 conn_common.go           | 17 +++++++++--------
 conn_test.go             |  8 ++++----
 frame_test.go            |  5 +++++
 handshake.go             | 17 ++++++++++++-----
 handshake_test.go        |  7 +++++--
 9 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b07c54b..a53a469 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -7,24 +7,18 @@ jobs:
     container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
     steps:
       - uses: actions/checkout@v1
-        with:
-          fetch-depth: 1
       - run: ./ci/fmt.sh
   lint:
     runs-on: ubuntu-latest
     container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
     steps:
       - uses: actions/checkout@v1
-        with:
-          fetch-depth: 1
       - run: ./ci/lint.sh
   test:
     runs-on: ubuntu-latest
     container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
     steps:
       - uses: actions/checkout@v1
-        with:
-          fetch-depth: 1
       - run: ./ci/test.sh
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
@@ -33,6 +27,4 @@ jobs:
     container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
     steps:
       - uses: actions/checkout@v1
-        with:
-          fetch-depth: 1
       - run: ./ci/wasm.sh
diff --git a/assert_test.go b/assert_test.go
index 8970c54..e67ed53 100644
--- a/assert_test.go
+++ b/assert_test.go
@@ -6,6 +6,7 @@ import (
 	"math/rand"
 	"reflect"
 	"strings"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 
@@ -13,6 +14,10 @@ import (
 	"nhooyr.io/websocket/wsjson"
 )
 
+func init() {
+	rand.Seed(time.Now().UnixNano())
+}
+
 // https://github.com/google/go-cmp/issues/40#issuecomment-328615283
 func cmpDiff(exp, act interface{}) string {
 	return cmp.Diff(exp, act, deepAllowUnexported(exp, act))
diff --git a/ci/wasm.sh b/ci/wasm.sh
index 134b60b..c1d9a40 100755
--- a/ci/wasm.sh
+++ b/ci/wasm.sh
@@ -25,7 +25,7 @@ go install github.com/agnivade/wasmbrowsertest
 export WS_ECHO_SERVER_URL
 GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./...
 
-kill "$wsjstestPID"
+kill "$wsjstestPID" || true
 if ! wait "$wsjstestPID"; then
   echo "--- wsjstest exited unsuccessfully"
   echo "output:"
diff --git a/conn.go b/conn.go
index 37c4cac..d74b875 100644
--- a/conn.go
+++ b/conn.go
@@ -5,13 +5,12 @@ package websocket
 import (
 	"bufio"
 	"context"
-	cryptorand "crypto/rand"
+	"crypto/rand"
 	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
 	"log"
-	"math/rand"
 	"runtime"
 	"strconv"
 	"sync"
@@ -82,6 +81,7 @@ type Conn struct {
 	setReadTimeout  chan context.Context
 	setWriteTimeout chan context.Context
 
+	pingCounter   *atomicInt64
 	activePingsMu sync.Mutex
 	activePings   map[string]chan<- struct{}
 }
@@ -100,6 +100,7 @@ func (c *Conn) init() {
 	c.setReadTimeout = make(chan context.Context)
 	c.setWriteTimeout = make(chan context.Context)
 
+	c.pingCounter = &atomicInt64{}
 	c.activePings = make(map[string]chan<- struct{})
 
 	c.writeHeaderBuf = makeWriteHeaderBuf()
@@ -669,7 +670,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
 	c.writeHeader.payloadLength = int64(len(p))
 
 	if c.client {
-		_, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:])
+		_, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:])
 		if err != nil {
 			return 0, fmt.Errorf("failed to generate masking key: %w", err)
 		}
@@ -839,10 +840,6 @@ func (c *Conn) writeClose(p []byte, cerr error) error {
 	return nil
 }
 
-func init() {
-	rand.Seed(time.Now().UnixNano())
-}
-
 // Ping sends a ping to the peer and waits for a pong.
 // Use this to measure latency or ensure the peer is responsive.
 // Ping must be called concurrently with Reader as it does
@@ -851,10 +848,9 @@ func init() {
 //
 // TCP Keepalives should suffice for most use cases.
 func (c *Conn) Ping(ctx context.Context) error {
-	id := rand.Uint64()
-	p := strconv.FormatUint(id, 10)
+	p := c.pingCounter.Increment(1)
 
-	err := c.ping(ctx, p)
+	err := c.ping(ctx, strconv.FormatInt(p, 10))
 	if err != nil {
 		return fmt.Errorf("failed to ping: %w", err)
 	}
diff --git a/conn_common.go b/conn_common.go
index e7a0103..8233e4a 100644
--- a/conn_common.go
+++ b/conn_common.go
@@ -211,21 +211,22 @@ func (c *Conn) setCloseErr(err error) {
 
 // See https://github.com/nhooyr/websocket/issues/153
 type atomicInt64 struct {
-	v atomic.Value
+	v int64
 }
 
 func (v *atomicInt64) Load() int64 {
-	i, ok := v.v.Load().(int64)
-	if !ok {
-		return 0
-	}
-	return i
+	return atomic.LoadInt64(&v.v)
 }
 
 func (v *atomicInt64) Store(i int64) {
-	v.v.Store(i)
+	atomic.StoreInt64(&v.v, i)
 }
 
 func (v *atomicInt64) String() string {
-	return fmt.Sprint(v.v.Load())
+	return fmt.Sprint(v.Load())
+}
+
+// Increment increments the value and returns the new value.
+func (v *atomicInt64) Increment(delta int64) int64 {
+	return atomic.AddInt64(&v.v, delta)
 }
diff --git a/conn_test.go b/conn_test.go
index 8846979..c948c43 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -37,6 +37,10 @@ import (
 	"nhooyr.io/websocket/wspb"
 )
 
+func init() {
+	rand.Seed(time.Now().UnixNano())
+}
+
 func TestHandshake(t *testing.T) {
 	t.Parallel()
 
@@ -911,10 +915,6 @@ func TestConn(t *testing.T) {
 	}
 }
 
-func init() {
-	rand.Seed(time.Now().UnixNano())
-}
-
 func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) {
 	var conns int64
 	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
diff --git a/frame_test.go b/frame_test.go
index 1a2054c..7d2a571 100644
--- a/frame_test.go
+++ b/frame_test.go
@@ -10,10 +10,15 @@ import (
 	"strconv"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 )
 
+func init() {
+	rand.Seed(time.Now().UnixNano())
+}
+
 func randBool() bool {
 	return rand.Intn(1) == 0
 }
diff --git a/handshake.go b/handshake.go
index 0b07808..d1a9fba 100644
--- a/handshake.go
+++ b/handshake.go
@@ -6,13 +6,13 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"crypto/rand"
 	"crypto/sha1"
 	"encoding/base64"
 	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
-	"math/rand"
 	"net/http"
 	"net/textproto"
 	"net/url"
@@ -299,7 +299,11 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
 	req.Header.Set("Connection", "Upgrade")
 	req.Header.Set("Upgrade", "websocket")
 	req.Header.Set("Sec-WebSocket-Version", "13")
-	req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
+	secWebSocketKey, err := makeSecWebSocketKey()
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
+	}
+	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
 	if len(opts.Subprotocols) > 0 {
 		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
 	}
@@ -403,8 +407,11 @@ func returnBufioWriter(bw *bufio.Writer) {
 	bufioWriterPool.Put(bw)
 }
 
-func makeSecWebSocketKey() string {
+func makeSecWebSocketKey() (string, error) {
 	b := make([]byte, 16)
-	rand.Read(b)
-	return base64.StdEncoding.EncodeToString(b)
+	_, err := io.ReadFull(rand.Reader, b)
+	if err != nil {
+		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
+	}
+	return base64.StdEncoding.EncodeToString(b), nil
 }
diff --git a/handshake_test.go b/handshake_test.go
index a3d9816..cb09353 100644
--- a/handshake_test.go
+++ b/handshake_test.go
@@ -367,14 +367,17 @@ func Test_verifyServerHandshake(t *testing.T) {
 			resp := w.Result()
 
 			r := httptest.NewRequest("GET", "/", nil)
-			key := makeSecWebSocketKey()
+			key, err := makeSecWebSocketKey()
+			if err != nil {
+				t.Fatal(err)
+			}
 			r.Header.Set("Sec-WebSocket-Key", key)
 
 			if resp.Header.Get("Sec-WebSocket-Accept") == "" {
 				resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
 			}
 
-			err := verifyServerResponse(r, resp)
+			err = verifyServerResponse(r, resp)
 			if (err == nil) != tc.success {
 				t.Fatalf("unexpected error: %+v", err)
 			}
-- 
GitLab