good morning!!!!

Skip to content
Snippets Groups Projects
conn_test.go 12.1 KiB
Newer Older
// +build !js

Anmol Sethi's avatar
Anmol Sethi committed
package websocket_test

import (
Anmol Sethi's avatar
Anmol Sethi committed
	"bytes"
Anmol Sethi's avatar
Anmol Sethi committed
	"context"
Anmol Sethi's avatar
Anmol Sethi committed
	"fmt"
Anmol Sethi's avatar
Anmol Sethi committed
	"io"
	"io/ioutil"
Anmol Sethi's avatar
Anmol Sethi committed
	"net/http"
	"net/http/httptest"
	"os"
	"os/exec"
	"strings"
Anmol Sethi's avatar
Anmol Sethi committed
	"testing"
	"time"

Anmol Sethi's avatar
Anmol Sethi committed
	"github.com/gin-gonic/gin"
Anmol Sethi's avatar
Anmol Sethi committed
	"github.com/golang/protobuf/ptypes"
	"github.com/golang/protobuf/ptypes/duration"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/internal/errd"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/internal/test/assert"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/internal/test/wstest"
	"nhooyr.io/websocket/internal/test/xrand"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/internal/xsync"
Anmol Sethi's avatar
Anmol Sethi committed
	"nhooyr.io/websocket/wsjson"
	"nhooyr.io/websocket/wspb"
Anmol Sethi's avatar
Anmol Sethi committed
func TestConn(t *testing.T) {
	t.Parallel()

Anmol Sethi's avatar
Anmol Sethi committed
	t.Run("fuzzData", func(t *testing.T) {
		compressionMode := func() websocket.CompressionMode {
			return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1))
Anmol Sethi's avatar
Anmol Sethi committed
		for i := 0; i < 5; i++ {
			t.Run("", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
				tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
					CompressionMode:      compressionMode(),
					CompressionThreshold: xrand.Int(9999),
				}, &websocket.AcceptOptions{
					CompressionMode:      compressionMode(),
					CompressionThreshold: xrand.Int(9999),

				tt.goEchoLoop(c2)
Anmol Sethi's avatar
Anmol Sethi committed
				c1.SetReadLimit(131072)
Anmol Sethi's avatar
Anmol Sethi committed
				for i := 0; i < 5; i++ {
					err := wstest.Echo(tt.ctx, c1, 131072)
Anmol Sethi's avatar
Anmol Sethi committed
					assert.Success(t, err)
				err := c1.Close(websocket.StatusNormalClosure, "")
Anmol Sethi's avatar
Anmol Sethi committed
				assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed

	t.Run("badClose", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		c2.CloseRead(tt.ctx)
Anmol Sethi's avatar
Anmol Sethi committed

		err := c1.Close(-1, "")
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
Anmol Sethi's avatar
Anmol Sethi committed
	})

	t.Run("ping", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		c1.CloseRead(tt.ctx)
		c2.CloseRead(tt.ctx)
Anmol Sethi's avatar
Anmol Sethi committed

		for i := 0; i < 10; i++ {
			err := c1.Ping(tt.ctx)
Anmol Sethi's avatar
Anmol Sethi committed
			assert.Success(t, err)
		err := c1.Close(websocket.StatusNormalClosure, "")
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
	})

	t.Run("badPing", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		c2.CloseRead(tt.ctx)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
		defer cancel()

		err := c1.Ping(ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, "failed to wait for pong")
Anmol Sethi's avatar
Anmol Sethi committed
	})

	t.Run("concurrentWrite", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
		tt.goDiscardLoop(c2)
Anmol Sethi's avatar
Anmol Sethi committed

		msg := xrand.Bytes(xrand.Int(9999))
		const count = 100
		errs := make(chan error, count)

		for i := 0; i < count; i++ {
			go func() {
Anmol Sethi's avatar
Anmol Sethi committed
				select {
				case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
				case <-tt.ctx.Done():
					return
				}
Anmol Sethi's avatar
Anmol Sethi committed
			}()
		}

		for i := 0; i < count; i++ {
Anmol Sethi's avatar
Anmol Sethi committed
			select {
			case err := <-errs:
				assert.Success(t, err)
			case <-tt.ctx.Done():
				t.Fatal(tt.ctx.Err())
			}
		err := c1.Close(websocket.StatusNormalClosure, "")
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
	})
Anmol Sethi's avatar
Anmol Sethi committed

	t.Run("concurrentWriteError", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, _ := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		_, err := c1.Writer(tt.ctx, websocket.MessageText)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed

		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
		defer cancel()

		err = c1.Write(ctx, websocket.MessageText, []byte("x"))
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Equal(t, "write error", context.DeadlineExceeded, err)
Anmol Sethi's avatar
Anmol Sethi committed
	})

	t.Run("netConn", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
Anmol Sethi's avatar
Anmol Sethi committed

		// Does not give any confidence but at least ensures no crashes.
		d, _ := tt.ctx.Deadline()
Anmol Sethi's avatar
Anmol Sethi committed
		n1.SetDeadline(d)
		n1.SetDeadline(time.Time{})

Anmol Sethi's avatar
Anmol Sethi committed
		assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
		assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String())
		assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network())
Anmol Sethi's avatar
Anmol Sethi committed

		errs := xsync.Go(func() error {
			_, err := n2.Write([]byte("hello"))
			if err != nil {
				return err
			}
			return n2.Close()
		})

		b, err := ioutil.ReadAll(n1)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed

		_, err = n1.Read(nil)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Equal(t, "read error", err, io.EOF)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case err := <-errs:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		assert.Equal(t, "read msg", []byte("hello"), b)
Anmol Sethi's avatar
Anmol Sethi committed
	t.Run("netConn/BadMsg", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
Anmol Sethi's avatar
Anmol Sethi committed

		c2.CloseRead(tt.ctx)
Anmol Sethi's avatar
Anmol Sethi committed
		errs := xsync.Go(func() error {
			_, err := n2.Write([]byte("hello"))
		_, err := ioutil.ReadAll(n1)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case err := <-errs:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}
Anmol Sethi's avatar
Anmol Sethi committed
	})
Anmol Sethi's avatar
Anmol Sethi committed

	t.Run("netConn/readLimit", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

Anmol Sethi's avatar
Anmol Sethi committed
		s := strings.Repeat("papa", 1<<20)
		errs := xsync.Go(func() error {
			_, err := n2.Write([]byte(s))
			if err != nil {
				return err
			}
			return n2.Close()
		})

		b, err := ioutil.ReadAll(n1)
		assert.Success(t, err)

		_, err = n1.Read(nil)
		assert.Equal(t, "read error", err, io.EOF)

		select {
		case err := <-errs:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}

		assert.Equal(t, "read msg", s, string(b))
	})

Anmol Sethi's avatar
Anmol Sethi committed
	t.Run("wsjson", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		tt.goEchoLoop(c2)
Anmol Sethi's avatar
Anmol Sethi committed

		c1.SetReadLimit(1 << 30)
Anmol Sethi's avatar
Anmol Sethi committed

		exp := xrand.String(xrand.Int(131072))
Anmol Sethi's avatar
Anmol Sethi committed

		werr := xsync.Go(func() error {
			return wsjson.Write(tt.ctx, c1, exp)
Anmol Sethi's avatar
Anmol Sethi committed
		})
Anmol Sethi's avatar
Anmol Sethi committed

		var act interface{}
		err := wsjson.Read(tt.ctx, c1, &act)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
		assert.Equal(t, "read msg", exp, act)
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		select {
		case err := <-werr:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
		err = c1.Close(websocket.StatusNormalClosure, "")
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
	})

	t.Run("wspb", func(t *testing.T) {
Anmol Sethi's avatar
Anmol Sethi committed
		tt, c1, c2 := newConnTest(t, nil, nil)
Anmol Sethi's avatar
Anmol Sethi committed

		tt.goEchoLoop(c2)
Anmol Sethi's avatar
Anmol Sethi committed

		exp := ptypes.DurationProto(100)
		err := wspb.Write(tt.ctx, c1, exp)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed

		act := &duration.Duration{}
		err = wspb.Read(tt.ctx, c1, act)
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
		assert.Equal(t, "read msg", exp, act)
Anmol Sethi's avatar
Anmol Sethi committed

		err = c1.Close(websocket.StatusNormalClosure, "")
Anmol Sethi's avatar
Anmol Sethi committed
		assert.Success(t, err)
Anmol Sethi's avatar
Anmol Sethi committed
	})
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
func TestWasm(t *testing.T) {
	t.Parallel()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Anmol Sethi's avatar
Anmol Sethi committed
		err := echoServer(w, r, &websocket.AcceptOptions{
			Subprotocols:       []string{"echo"},
			InsecureSkipVerify: true,
Anmol Sethi's avatar
Anmol Sethi committed
		})
Anmol Sethi's avatar
Anmol Sethi committed
		if err != nil {
Anmol Sethi's avatar
Anmol Sethi committed
			t.Error(err)
Anmol Sethi's avatar
Anmol Sethi committed
		}
Anmol Sethi's avatar
Anmol Sethi committed
	}))
Anmol Sethi's avatar
Anmol Sethi committed
	defer s.Close()
Anmol Sethi's avatar
Anmol Sethi committed

Anmol Sethi's avatar
Anmol Sethi committed
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
Anmol Sethi's avatar
Anmol Sethi committed
	defer cancel()
Anmol Sethi's avatar
Anmol Sethi committed

	cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
	cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
Anmol Sethi's avatar
Anmol Sethi committed

	b, err := cmd.CombinedOutput()
	if err != nil {
		t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
Anmol Sethi's avatar
Anmol Sethi committed
	}
}
Anmol Sethi's avatar
Anmol Sethi committed

func assertCloseStatus(exp websocket.StatusCode, err error) error {
	if websocket.CloseStatus(err) == -1 {
		return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	if websocket.CloseStatus(err) != exp {
		return fmt.Errorf("expected close status %v but got %v", exp, err)
Anmol Sethi's avatar
Anmol Sethi committed
	}
	return nil
}
Anmol Sethi's avatar
Anmol Sethi committed
// connTest carries the per-test state used by the goEchoLoop and goDiscardLoop
// helpers.
type connTest struct {
	// t is the test or benchmark that owns the connections; used for Cleanup and
	// error reporting.
	t testing.TB
	// ctx bounds all operations in the test (newConnTest gives it a 30s timeout).
	ctx context.Context
}

func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
	if t, ok := t.(*testing.T); ok {
		t.Parallel()
	}
Anmol Sethi's avatar
Anmol Sethi committed
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
Anmol Sethi's avatar
Anmol Sethi committed
	tt = &connTest{t: t, ctx: ctx}
Anmol Sethi's avatar
Anmol Sethi committed
	c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
	if xrand.Bool() {
		c1, c2 = c2, c1
	}
	t.Cleanup(func() {
		// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
		// blocking the test shutting down.
		go c2.Close(websocket.StatusInternalError, "")
		go c1.Close(websocket.StatusInternalError, "")
Anmol Sethi's avatar
Anmol Sethi committed
	})

	return tt, c1, c2
Anmol Sethi's avatar
Anmol Sethi committed
func (tt *connTest) goEchoLoop(c *websocket.Conn) {
	ctx, cancel := context.WithCancel(tt.ctx)

	echoLoopErr := xsync.Go(func() error {
		err := wstest.EchoLoop(ctx, c)
		return assertCloseStatus(websocket.StatusNormalClosure, err)
	})
	tt.t.Cleanup(func() {
		cancel()
		err := <-echoLoopErr
		if err != nil {
			tt.t.Errorf("echo loop error: %v", err)
		}
	})
}

Anmol Sethi's avatar
Anmol Sethi committed
func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
	ctx, cancel := context.WithCancel(tt.ctx)

	discardLoopErr := xsync.Go(func() error {
		defer c.Close(websocket.StatusInternalError, "")

		for {
			_, _, err := c.Read(ctx)
			if err != nil {
				return assertCloseStatus(websocket.StatusNormalClosure, err)
	tt.t.Cleanup(func() {
		cancel()
		err := <-discardLoopErr
		if err != nil {
			tt.t.Errorf("discard loop error: %v", err)
		}
	})
}
Anmol Sethi's avatar
Anmol Sethi committed

func BenchmarkConn(b *testing.B) {
	var benchCases = []struct {
		name string
		mode websocket.CompressionMode
	}{
		{
			name: "disabledCompress",
Anmol Sethi's avatar
Anmol Sethi committed
			mode: websocket.CompressionDisabled,
		},
		{
			name: "compressContextTakeover",
Anmol Sethi's avatar
Anmol Sethi committed
			mode: websocket.CompressionContextTakeover,
		},
		{
			name: "compressNoContext",
Anmol Sethi's avatar
Anmol Sethi committed
			mode: websocket.CompressionNoContextTakeover,
		},
	}
	for _, bc := range benchCases {
		b.Run(bc.name, func(b *testing.B) {
			bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
				CompressionMode: bc.mode,
			}, &websocket.AcceptOptions{
				CompressionMode: bc.mode,
Anmol Sethi's avatar
Anmol Sethi committed

			bb.goEchoLoop(c2)

			bytesWritten := c1.RecordBytesWritten()
			bytesRead := c1.RecordBytesRead()

			msg := []byte(strings.Repeat("1234", 128))
			readBuf := make([]byte, len(msg))
			writes := make(chan struct{})
Anmol Sethi's avatar
Anmol Sethi committed
			defer close(writes)
			werrs := make(chan error)

			go func() {
				for range writes {
Anmol Sethi's avatar
Anmol Sethi committed
					select {
					case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
					case <-bb.ctx.Done():
						return
					}
Anmol Sethi's avatar
Anmol Sethi committed
				}
			}()
			b.SetBytes(int64(len(msg)))
Anmol Sethi's avatar
Anmol Sethi committed
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
Anmol Sethi's avatar
Anmol Sethi committed
				select {
				case writes <- struct{}{}:
				case <-bb.ctx.Done():
					b.Fatal(bb.ctx.Err())
				}
Anmol Sethi's avatar
Anmol Sethi committed

				typ, r, err := c1.Reader(bb.ctx)
				if err != nil {
					b.Fatal(err)
				}
				if websocket.MessageText != typ {
					assert.Equal(b, "data type", websocket.MessageText, typ)
Anmol Sethi's avatar
Anmol Sethi committed
				}

				_, err = io.ReadFull(r, readBuf)
				if err != nil {
					b.Fatal(err)
				}

				n2, err := r.Read(readBuf)
				if err != io.EOF {
					assert.Equal(b, "read err", io.EOF, err)
				}
				if n2 != 0 {
					assert.Equal(b, "n2", 0, n2)
				}

				if !bytes.Equal(msg, readBuf) {
					assert.Equal(b, "msg", msg, readBuf)
Anmol Sethi's avatar
Anmol Sethi committed
				select {
				case err = <-werrs:
				case <-bb.ctx.Done():
					b.Fatal(bb.ctx.Err())
				}
Anmol Sethi's avatar
Anmol Sethi committed
				if err != nil {
					b.Fatal(err)
				}
			}
			b.StopTimer()

			b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
			b.ReportMetric(float64(*bytesRead/b.N), "read/op")

Anmol Sethi's avatar
Anmol Sethi committed
			err := c1.Close(websocket.StatusNormalClosure, "")
			assert.Success(b, err)
		})
	}
}
Anmol Sethi's avatar
Anmol Sethi committed
// echoServer upgrades the request to a WebSocket connection using opts.
// It then echoes messages with wstest.EchoLoop until the request context ends
// or the loop fails. Any error is wrapped with "echo server failed".
//
// The loop must end with a StatusNormalClosure close. Any other ending is
// reported as an error. The connection is closed with StatusInternalError on
// return (a no-op presumably, if it was already closed normally).
func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) {
	defer errd.Wrap(&err, "echo server failed")

	conn, acceptErr := websocket.Accept(w, r, opts)
	if acceptErr != nil {
		return acceptErr
	}
	defer conn.Close(websocket.StatusInternalError, "")

	loopErr := wstest.EchoLoop(r.Context(), conn)
	return assertCloseStatus(websocket.StatusNormalClosure, loopErr)
}

func TestGin(t *testing.T) {
	t.Parallel()

Anmol Sethi's avatar
Anmol Sethi committed
	gin.SetMode(gin.ReleaseMode)
	r := gin.New()
	r.GET("/", func(ginCtx *gin.Context) {
		err := echoServer(ginCtx.Writer, ginCtx.Request, nil)
		if err != nil {
			t.Error(err)
		}
	})

	s := httptest.NewServer(r)
	defer s.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()

	c, _, err := websocket.Dial(ctx, s.URL, nil)
	assert.Success(t, err)
	defer c.Close(websocket.StatusInternalError, "")

	err = wsjson.Write(ctx, c, "hello")
	assert.Success(t, err)

	var v interface{}
	err = wsjson.Read(ctx, c, &v)
	assert.Success(t, err)
	assert.Equal(t, "read msg", "hello", v)

	err = c.Close(websocket.StatusNormalClosure, "")
	assert.Success(t, err)