From e476358de061353e5069f490ac09dd3815513e01 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sun, 29 Sep 2019 18:21:31 -0500 Subject: [PATCH] Improve usage of math/rand versus crypto/rand math/rand was being used inappropiately and did not have a init function for every file it was used in. --- .github/workflows/ci.yml | 8 -------- assert_test.go | 5 +++++ ci/wasm.sh | 2 +- conn.go | 16 ++++++---------- conn_common.go | 17 +++++++++-------- conn_test.go | 8 ++++---- frame_test.go | 5 +++++ handshake.go | 17 ++++++++++++----- handshake_test.go | 7 +++++-- 9 files changed, 47 insertions(+), 38 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b07c54b..a53a469 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,24 +7,18 @@ jobs: container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/fmt.sh lint: runs-on: ubuntu-latest container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/lint.sh test: runs-on: ubuntu-latest container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/test.sh env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} @@ -33,6 +27,4 @@ jobs: container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f steps: - uses: actions/checkout@v1 - with: - fetch-depth: 1 - run: ./ci/wasm.sh diff --git a/assert_test.go b/assert_test.go index 8970c54..e67ed53 100644 --- a/assert_test.go +++ b/assert_test.go @@ -6,6 +6,7 @@ import ( "math/rand" "reflect" "strings" + "time" "github.com/google/go-cmp/cmp" @@ -13,6 +14,10 @@ import ( "nhooyr.io/websocket/wsjson" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + // https://github.com/google/go-cmp/issues/40#issuecomment-328615283 func cmpDiff(exp, act interface{}) string { return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) diff --git a/ci/wasm.sh b/ci/wasm.sh index 134b60b..c1d9a40 100755 --- a/ci/wasm.sh +++ b/ci/wasm.sh @@ -25,7 +25,7 @@ go install github.com/agnivade/wasmbrowsertest export WS_ECHO_SERVER_URL GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./... -kill "$wsjstestPID" +kill "$wsjstestPID" || true if ! wait "$wsjstestPID"; then echo "--- wsjstest exited unsuccessfully" echo "output:" diff --git a/conn.go b/conn.go index 37c4cac..d74b875 100644 --- a/conn.go +++ b/conn.go @@ -5,13 +5,12 @@ package websocket import ( "bufio" "context" - cryptorand "crypto/rand" + "crypto/rand" "errors" "fmt" "io" "io/ioutil" "log" - "math/rand" "runtime" "strconv" "sync" @@ -82,6 +81,7 @@ type Conn struct { setReadTimeout chan context.Context setWriteTimeout chan context.Context + pingCounter *atomicInt64 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } @@ -100,6 +100,7 @@ func (c *Conn) init() { c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) + c.pingCounter = &atomicInt64{} c.activePings = make(map[string]chan<- struct{}) c.writeHeaderBuf = makeWriteHeaderBuf() @@ -669,7 +670,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte c.writeHeader.payloadLength = int64(len(p)) if c.client { - _, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:]) + _, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:]) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } @@ -839,10 +840,6 @@ func (c *Conn) writeClose(p []byte, cerr error) error { return nil } -func init() { - rand.Seed(time.Now().UnixNano()) -} - // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // Ping must be called concurrently with Reader as it does @@ -851,10 +848,9 @@ func init() { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - id := rand.Uint64() - p := strconv.FormatUint(id, 10) + p := c.pingCounter.Increment(1) - err := c.ping(ctx, p) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) } diff --git a/conn_common.go b/conn_common.go index e7a0103..8233e4a 100644 --- a/conn_common.go +++ b/conn_common.go @@ -211,21 +211,22 @@ func (c *Conn) setCloseErr(err error) { // See https://github.com/nhooyr/websocket/issues/153 type atomicInt64 struct { - v atomic.Value + v int64 } func (v *atomicInt64) Load() int64 { - i, ok := v.v.Load().(int64) - if !ok { - return 0 - } - return i + return atomic.LoadInt64(&v.v) } func (v *atomicInt64) Store(i int64) { - v.v.Store(i) + atomic.StoreInt64(&v.v, i) } func (v *atomicInt64) String() string { - return fmt.Sprint(v.v.Load()) + return fmt.Sprint(v.Load()) +} + +// Increment increments the value and returns the new value. +func (v *atomicInt64) Increment(delta int64) int64 { + return atomic.AddInt64(&v.v, delta) } diff --git a/conn_test.go b/conn_test.go index 8846979..c948c43 100644 --- a/conn_test.go +++ b/conn_test.go @@ -37,6 +37,10 @@ import ( "nhooyr.io/websocket/wspb" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + func TestHandshake(t *testing.T) { t.Parallel() @@ -911,10 +915,6 @@ func TestConn(t *testing.T) { } } -func init() { - rand.Seed(time.Now().UnixNano()) -} - func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { var conns int64 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/frame_test.go b/frame_test.go index 1a2054c..7d2a571 100644 --- a/frame_test.go +++ b/frame_test.go @@ -10,10 +10,15 @@ import ( "strconv" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + func randBool() bool { return rand.Intn(1) == 0 } diff --git a/handshake.go b/handshake.go index 0b07808..d1a9fba 100644 --- a/handshake.go +++ b/handshake.go @@ -6,13 +6,13 @@ import ( "bufio" "bytes" "context" + "crypto/rand" "crypto/sha1" "encoding/base64" "errors" "fmt" "io" "io/ioutil" - "math/rand" "net/http" "net/textproto" "net/url" @@ -299,7 +299,11 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey()) + secWebSocketKey, err := makeSecWebSocketKey() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } @@ -403,8 +407,11 @@ func returnBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } -func makeSecWebSocketKey() string { +func makeSecWebSocketKey() (string, error) { b := make([]byte, 16) - rand.Read(b) - return base64.StdEncoding.EncodeToString(b) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil } diff --git a/handshake_test.go b/handshake_test.go index a3d9816..cb09353 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -367,14 +367,17 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key := makeSecWebSocketKey() + key, err := makeSecWebSocketKey() + if err != nil { + t.Fatal(err) + } r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - err := verifyServerResponse(r, resp) + err = verifyServerResponse(r, resp) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) } -- GitLab