good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit e476358d authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Improve usage of math/rand versus crypto/rand

math/rand was being used inappropiately and did not have
a init function for every file it was used in.
parent 4f91d7a5
Branches
Tags
No related merge requests found
......@@ -7,24 +7,18 @@ jobs:
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
steps:
- uses: actions/checkout@v1
with:
fetch-depth: 1
- run: ./ci/fmt.sh
lint:
runs-on: ubuntu-latest
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
steps:
- uses: actions/checkout@v1
with:
fetch-depth: 1
- run: ./ci/lint.sh
test:
runs-on: ubuntu-latest
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
steps:
- uses: actions/checkout@v1
with:
fetch-depth: 1
- run: ./ci/test.sh
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
......@@ -33,6 +27,4 @@ jobs:
container: docker://nhooyr/websocket-ci@sha256:b6331f8f64803c8b1bbd2a0ee9e2547317e0de2348bccd9c8dbcc1d88ff5747f
steps:
- uses: actions/checkout@v1
with:
fetch-depth: 1
- run: ./ci/wasm.sh
......@@ -6,6 +6,7 @@ import (
"math/rand"
"reflect"
"strings"
"time"
"github.com/google/go-cmp/cmp"
......@@ -13,6 +14,10 @@ import (
"nhooyr.io/websocket/wsjson"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// https://github.com/google/go-cmp/issues/40#issuecomment-328615283
func cmpDiff(exp, act interface{}) string {
return cmp.Diff(exp, act, deepAllowUnexported(exp, act))
......
......@@ -25,7 +25,7 @@ go install github.com/agnivade/wasmbrowsertest
export WS_ECHO_SERVER_URL
GOOS=js GOARCH=wasm go test -exec=wasmbrowsertest ./...
kill "$wsjstestPID"
kill "$wsjstestPID" || true
if ! wait "$wsjstestPID"; then
echo "--- wsjstest exited unsuccessfully"
echo "output:"
......
......@@ -5,13 +5,12 @@ package websocket
import (
"bufio"
"context"
cryptorand "crypto/rand"
"crypto/rand"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"runtime"
"strconv"
"sync"
......@@ -82,6 +81,7 @@ type Conn struct {
setReadTimeout chan context.Context
setWriteTimeout chan context.Context
pingCounter *atomicInt64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
}
......@@ -100,6 +100,7 @@ func (c *Conn) init() {
c.setReadTimeout = make(chan context.Context)
c.setWriteTimeout = make(chan context.Context)
c.pingCounter = &atomicInt64{}
c.activePings = make(map[string]chan<- struct{})
c.writeHeaderBuf = makeWriteHeaderBuf()
......@@ -669,7 +670,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
c.writeHeader.payloadLength = int64(len(p))
if c.client {
_, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:])
_, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:])
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
......@@ -839,10 +840,6 @@ func (c *Conn) writeClose(p []byte, cerr error) error {
return nil
}
func init() {
rand.Seed(time.Now().UnixNano())
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
......@@ -851,10 +848,9 @@ func init() {
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
id := rand.Uint64()
p := strconv.FormatUint(id, 10)
p := c.pingCounter.Increment(1)
err := c.ping(ctx, p)
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
......
......@@ -211,21 +211,22 @@ func (c *Conn) setCloseErr(err error) {
// See https://github.com/nhooyr/websocket/issues/153
type atomicInt64 struct {
v atomic.Value
v int64
}
func (v *atomicInt64) Load() int64 {
i, ok := v.v.Load().(int64)
if !ok {
return 0
}
return i
return atomic.LoadInt64(&v.v)
}
func (v *atomicInt64) Store(i int64) {
v.v.Store(i)
atomic.StoreInt64(&v.v, i)
}
func (v *atomicInt64) String() string {
return fmt.Sprint(v.v.Load())
return fmt.Sprint(v.Load())
}
// Increment increments the value and returns the new value.
func (v *atomicInt64) Increment(delta int64) int64 {
return atomic.AddInt64(&v.v, delta)
}
......@@ -37,6 +37,10 @@ import (
"nhooyr.io/websocket/wspb"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func TestHandshake(t *testing.T) {
t.Parallel()
......@@ -911,10 +915,6 @@ func TestConn(t *testing.T) {
}
}
func init() {
rand.Seed(time.Now().UnixNano())
}
func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) {
var conns int64
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
......
......@@ -10,10 +10,15 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func randBool() bool {
return rand.Intn(1) == 0
}
......
......@@ -6,13 +6,13 @@ import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/textproto"
"net/url"
......@@ -299,7 +299,11 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
secWebSocketKey, err := makeSecWebSocketKey()
if err != nil {
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
......@@ -403,8 +407,11 @@ func returnBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
func makeSecWebSocketKey() string {
func makeSecWebSocketKey() (string, error) {
b := make([]byte, 16)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
_, err := io.ReadFull(rand.Reader, b)
if err != nil {
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
......@@ -367,14 +367,17 @@ func Test_verifyServerHandshake(t *testing.T) {
resp := w.Result()
r := httptest.NewRequest("GET", "/", nil)
key := makeSecWebSocketKey()
key, err := makeSecWebSocketKey()
if err != nil {
t.Fatal(err)
}
r.Header.Set("Sec-WebSocket-Key", key)
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
}
err := verifyServerResponse(r, resp)
err = verifyServerResponse(r, resp)
if (err == nil) != tc.success {
t.Fatalf("unexpected error: %+v", err)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment