good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
package websocket_test
import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"os/exec"
"reflect"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/duration"
"github.com/google/go-cmp/cmp"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"nhooyr.io/websocket/wspb"
)
func TestHandshake(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
client func(ctx context.Context, url string) error
server func(w http.ResponseWriter, r *http.Request) error
}{
{
name: "handshake",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"myproto"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"myproto"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
checkHeader := func(h, exp string) {
t.Helper()
value := resp.Header.Get(h)
if exp != value {
t.Errorf("expected different value for header %v: %v", h, cmp.Diff(exp, value))
}
}
checkHeader("Connection", "Upgrade")
checkHeader("Upgrade", "websocket")
checkHeader("Sec-WebSocket-Protocol", "myproto")
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "closeError",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
err = wsjson.Write(r.Context(), c, "hello")
if err != nil {
return err
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
var m string
err = wsjson.Read(ctx, c, &m)
if err != nil {
return err
}
if m != "hello" {
return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m)
}
_, _, err = c.Reader(ctx)
var cerr websocket.CloseError
if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError {
return xerrors.Errorf("unexpected error: %+v", err)
}
return nil
},
},
{
name: "netConn",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
nc := websocket.NetConn(c)
defer nc.Close()
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
for i := 0; i < 3; i++ {
_, err = nc.Write([]byte("hello"))
if err != nil {
return err
}
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
nc := websocket.NetConn(c)
defer nc.Close()
nc.SetReadDeadline(time.Now().Add(time.Second * 15))
read := func() error {
p := make([]byte, len("hello"))
// We do not use io.ReadFull here as it masks EOFs.
// See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024
_, err = nc.Read(p)
if err != nil {
return err
}
if string(p) != "hello" {
return xerrors.Errorf("unexpected payload %q received", string(p))
}
return nil
}
for i := 0; i < 3; i++ {
err = read()
if err != nil {
return err
}
}
// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
err = read()
if err != io.EOF {
return err
}
err = read()
if err != io.EOF {
return err
}
return nil
},
},
{
name: "defaultSubprotocol",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "" {
return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol())
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "" {
return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol())
}
return nil
},
},
{
name: "subprotocol",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"echo", "lar"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "echo" {
return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol())
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"poof", "echo"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "echo" {
return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol())
}
return nil
},
},
{
name: "badOrigin",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err == nil {
c.Close(websocket.StatusInternalError, "")
return xerrors.New("expected error regarding bad origin")
}
return nil
},
client: func(ctx context.Context, u string) error {
h := http.Header{}
h.Set("Origin", "http://unauthorized.com")
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPHeader: h,
})
if err == nil {
c.Close(websocket.StatusInternalError, "")
return xerrors.New("expected handshake failure")
}
return nil
},
},
{
name: "acceptSecureOrigin",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
h := http.Header{}
h.Set("Origin", u)
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPHeader: h,
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
},
{
name: "acceptInsecureOrigin",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
InsecureSkipVerify: true,
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
h := http.Header{}
h.Set("Origin", "https://example.com")
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPHeader: h,
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
},
{
name: "jsonEcho",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
defer cancel()
write := func() error {
v := map[string]interface{}{
"anmol": "wowow",
}
err := wsjson.Write(ctx, c, v)
return err
}
err = write()
if err != nil {
return err
}
err = write()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
read := func() error {
var v interface{}
err := wsjson.Read(ctx, c, &v)
if err != nil {
return err
}
exp := map[string]interface{}{
"anmol": "wowow",
}
if !reflect.DeepEqual(exp, v) {
return xerrors.Errorf("expected %v but got %v", exp, v)
}
return nil
}
err = read()
if err != nil {
return err
}
err = read()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "protobufEcho",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
defer cancel()
write := func() error {
err := wspb.Write(ctx, c, ptypes.DurationProto(100))
return err
}
err = write()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
read := func() error {
var v duration.Duration
err := wspb.Read(ctx, c, &v)
if err != nil {
return err
}
d, err := ptypes.Duration(&v)
if err != nil {
return xerrors.Errorf("failed to convert duration.Duration to time.Duration: %w", err)
}
const exp = time.Duration(100)
if !reflect.DeepEqual(exp, d) {
return xerrors.Errorf("expected %v but got %v", exp, d)
}
return nil
}
err = read()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "cookies",
server: func(w http.ResponseWriter, r *http.Request) error {
cookie, err := r.Cookie("mycookie")
if err != nil {
return xerrors.Errorf("request is missing mycookie: %w", err)
}
if cookie.Value != "myvalue" {
return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
}
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
jar, err := cookiejar.New(nil)
if err != nil {
return xerrors.Errorf("failed to create cookie jar: %w", err)
}
parsedURL, err := url.Parse(u)
if err != nil {
return xerrors.Errorf("failed to parse url: %w", err)
}
parsedURL.Scheme = "http"
jar.SetCookies(parsedURL, []*http.Cookie{
{
Name: "mycookie",
Value: "myvalue",
},
})
hc := &http.Client{
Jar: jar,
}
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPClient: hc,
})
if err != nil {
return err
}
c.Close(websocket.StatusInternalError, "")
return nil
},
},
{
name: "ping",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
errc := make(chan error, 1)
go func() {
_, _, err2 := c.Read(r.Context())
errc <- err2
}()
err = c.Ping(r.Context())
if err != nil {
return err
}
err = c.Write(r.Context(), websocket.MessageText, []byte("hi"))
if err != nil {
return err
}
err = <-errc
var ce websocket.CloseError
if xerrors.As(err, &ce) && ce.Code == websocket.StatusNormalClosure {
return nil
}
return xerrors.Errorf("unexpected error: %w", err)
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
// We read a message from the connection and then keep reading until
// the Ping completes.
done := make(chan struct{})
go func() {
_, _, err := c.Read(ctx)
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return
}
close(done)
c.Read(ctx)
}()
err = c.Ping(ctx)
if err != nil {
return err
}
<-done
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "readLimit",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
_, _, err = c.Read(r.Context())
if err == nil {
return xerrors.Errorf("expected error but got nil")
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
go c.Reader(ctx)
err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769)))
if err != nil {
return err
}
err = c.Ping(ctx)
var ce websocket.CloseError
if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig {
return xerrors.Errorf("unexpected error: %w", err)
}
return nil
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
err := tc.server(w, r)
if err != nil {
t.Errorf("server failed: %+v", err)
return
}
})
defer closeFn()
wsURL := strings.Replace(s.URL, "http", "ws", 1)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
err := tc.client(ctx, wsURL)
if err != nil {
t.Fatalf("client failed: %+v", err)
}
})
}
}
func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) {
var conns int64
s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&conns, 1)
defer atomic.AddInt64(&conns, -1)
fn.ServeHTTP(w, r)
}))
return s, func() {
s.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
for atomic.LoadInt64(&conns) > 0 {
if ctx.Err() != nil {
tb.Fatalf("waiting for server to come down timed out: %v", ctx.Err())
}
}
}
}
// https://github.com/crossbario/autobahn-python/tree/master/wstest
func TestAutobahnServer(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"echo"},
})
if err != nil {
t.Logf("server handshake failed: %+v", err)
return
}
echoLoop(r.Context(), c)
}))
defer s.Close()
spec := map[string]interface{}{
"outdir": "ci/out/wstestServerReports",
"servers": []interface{}{
map[string]interface{}{
"agent": "main",
"url": strings.Replace(s.URL, "http", "ws", 1),
},
},
"cases": []string{"*"},
// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
// more performance overhead. 7.5.1 is the same.
// 12.* and 13.* as we do not support compression.
"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
}
specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json")
if err != nil {
t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err)
}
defer specFile.Close()
e := json.NewEncoder(specFile)
e.SetIndent("", "\t")
err = e.Encode(spec)
if err != nil {
t.Fatalf("failed to write spec: %v", err)
}
err = specFile.Close()
if err != nil {
t.Fatalf("failed to close file: %v", err)
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
defer cancel()
args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()}
wstest := exec.CommandContext(ctx, "wstest", args...)
out, err := wstest.CombinedOutput()
if err != nil {
t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out)
}
checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
}
func echoLoop(ctx context.Context, c *websocket.Conn) {
defer c.Close(websocket.StatusInternalError, "")
c.SetReadLimit(1 << 40)
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
b := make([]byte, 32768)
echo := func() error {
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
w, err := c.Writer(ctx, typ)
if err != nil {
return err
}
_, err = io.CopyBuffer(w, r, b)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
return nil
}
for {
err := echo()
if err != nil {
return
}
}
}
func discardLoop(ctx context.Context, c *websocket.Conn) {
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
b := make([]byte, 32768)
echo := func() error {
_, r, err := c.Reader(ctx)
if err != nil {
return err
}
_, err = io.CopyBuffer(ioutil.Discard, r, b)
if err != nil {
return err
}
return nil
}
for {
err := echo()
if err != nil {
return
}
}
}
// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py
func TestAutobahnClient(t *testing.T) {
t.Parallel()
spec := map[string]interface{}{
"url": "ws://localhost:9001",
"outdir": "ci/out/wstestClientReports",
"cases": []string{"*"},
// See TestAutobahnServer for the reasons why we exclude these.
"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
}
specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json")
if err != nil {
t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err)
}
defer specFile.Close()
e := json.NewEncoder(specFile)
e.SetIndent("", "\t")
err = e.Encode(spec)
if err != nil {
t.Fatalf("failed to write spec: %v", err)
}
err = specFile.Close()
if err != nil {
t.Fatalf("failed to close file: %v", err)
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
defer cancel()
args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name()}
if os.Getenv("CI") == "" {
args = append([]string{"--debug"}, args...)
}
wstest := exec.CommandContext(ctx, "wstest", args...)
err = wstest.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
err := wstest.Process.Kill()
if err != nil {
t.Error(err)
}
}()
// Let it come up.
time.Sleep(time.Second * 5)
var cases int
func() {
c, _, err := websocket.Dial(ctx, "ws://localhost:9001/getCaseCount", websocket.DialOptions{})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer c.Close(websocket.StatusInternalError, "")
_, r, err := c.Reader(ctx)
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
cases, err = strconv.Atoi(string(b))
if err != nil {
t.Fatal(err)
}
c.Close(websocket.StatusNormalClosure, "")
}()
for i := 1; i <= cases; i++ {
func() {
ctx, cancel := context.WithTimeout(ctx, time.Second*45)
defer cancel()
c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/runCase?case=%v&agent=main", i), websocket.DialOptions{})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
echoLoop(ctx, c)
}()
}
c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/updateReports?agent=main"), websocket.DialOptions{})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
}
func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := ioutil.ReadFile(path)
if err != nil {
t.Fatalf("failed to read index.json: %v", err)
}
var indexJSON map[string]map[string]struct {
Behavior string `json:"behavior"`
BehaviorClose string `json:"behaviorClose"`
}
err = json.Unmarshal(wstestOut, &indexJSON)
if err != nil {
t.Fatalf("failed to unmarshal index.json: %v", err)
}
var failed bool
for _, tests := range indexJSON {
for test, result := range tests {
switch result.Behavior {
case "OK", "NON-STRICT", "INFORMATIONAL":
default:
failed = true
t.Errorf("test %v failed", test)
}
switch result.BehaviorClose {
case "OK", "INFORMATIONAL":
default:
failed = true
t.Errorf("bad close behaviour for test %v", test)
}
}
}
if failed {
path = strings.Replace(path, ".json", ".html", 1)
if os.Getenv("CI") == "" {
t.Errorf("wstest found failure, please see %q", path)
} else {
t.Errorf("wstest found failure, please run test.sh locally to see %q", path)
}
}
}
func benchConn(b *testing.B, echo, stream bool, size int) {
s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
b.Logf("server handshake failed: %+v", err)
return
}
if echo {
echoLoop(r.Context(), c)
} else {
discardLoop(r.Context(), c)
}
}))
defer closeFn()
wsURL := strings.Replace(s.URL, "http", "ws", 1)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{})
if err != nil {
b.Fatalf("failed to dial: %v", err)
}
defer c.Close(websocket.StatusInternalError, "")
msg := []byte(strings.Repeat("2", size))
readBuf := make([]byte, len(msg))
b.SetBytes(int64(len(msg)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if stream {
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
b.Fatal(err)
}
_, err = w.Write(msg)
if err != nil {
b.Fatal(err)
}
err = w.Close()
if err != nil {
b.Fatal(err)
}
} else {
err = c.Write(ctx, websocket.MessageText, msg)
if err != nil {
b.Fatal(err)
}
}
if echo {
_, r, err := c.Reader(ctx)
if err != nil {
b.Fatal(err)
}
_, err = io.ReadFull(r, readBuf)
if err != nil {
b.Fatal(err)
}
}
}
b.StopTimer()
c.Close(websocket.StatusNormalClosure, "")
}
func BenchmarkConn(b *testing.B) {
sizes := []int{
2,
16,
32,
512,
4096,
16384,
}
b.Run("write", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
b.Run("stream", func(b *testing.B) {
benchConn(b, false, true, size)
})
b.Run("buffer", func(b *testing.B) {
benchConn(b, false, false, size)
})
})
}
})
b.Run("echo", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
benchConn(b, false, false, size)
})
}
})
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"compress/flate"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
w, err := c.writer(ctx, typ)
if err != nil {
return nil, fmt.Errorf("failed to get writer: %w", err)
}
return w, nil
}
// Write writes a message to the connection.
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the compression threshold is not met, then it
// will write the message in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write msg: %w", err)
}
return nil
}
type msgWriter struct {
c *Conn
mu *mu
writeMu *mu
closed bool
ctx context.Context
opcode opcode
flate bool
trimWriter *trimLastFourBytesWriter
flateWriter *flate.Writer
}
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write),
}
}
if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter)
}
mw.flate = true
}
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
return !mw.c.copts.serverNoContextTakeover
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return c.msgWriter, nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
if err != nil {
return 0, err
}
if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
n, err := mw.Write(p)
if err != nil {
return n, err
}
err = mw.Close()
return n, err
}
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
}
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false
mw.trimWriter.reset()
return nil
}
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
}
}
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
}
}()
if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
mw.ensureFlate()
}
}
if mw.flate {
return mw.flateWriter.Write(p)
}
return mw.write(p)
}
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
mw.opcode = opContinuation
return n, nil
}
// Close flushes the frame to the connection.
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate: %w", err)
}
}
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
if mw.flate && !mw.flateContextTakeover() {
mw.putFlateWriter()
}
mw.mu.unlock()
return nil
}
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
}
mw.writeMu.forceLock()
mw.putFlateWriter()
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
defer c.writeFrameMu.unlock()
defer func() {
if c.isClosed() && opcode == opClose {
err = nil
}
if err != nil {
if ctx.Err() != nil {
err = ctx.Err()
} else if c.isClosed() {
err = net.ErrClosed
}
err = fmt.Errorf("failed to write frame: %w", err)
}
}()
c.closeStateMu.Lock()
closeSentErr := c.closeSentErr
c.closeStateMu.Unlock()
if closeSentErr != nil {
return 0, net.ErrClosed
}
select {
case <-c.closed:
return 0, net.ErrClosed
case c.writeTimeout <- ctx:
}
defer func() {
select {
case <-c.closed:
case c.writeTimeout <- context.Background():
}
}()
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
if c.client {
c.writeHeader.masked = true
_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
}
c.writeHeader.rsv1 = false
if flate && (opcode == opText || opcode == opBinary) {
c.writeHeader.rsv1 = true
}
err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
if err != nil {
return 0, err
}
n, err := c.writeFramePayload(p)
if err != nil {
return n, err
}
if c.writeHeader.fin {
err = c.bw.Flush()
if err != nil {
return n, fmt.Errorf("failed to flush: %w", err)
}
}
if opcode == opClose {
c.closeStateMu.Lock()
c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
closeReceived := c.closeReceivedErr != nil
c.closeStateMu.Unlock()
if closeReceived && !c.casClosing() {
c.writeFrameMu.unlock()
_ = c.close()
}
}
return n, nil
}
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
defer errd.Wrap(&err, "failed to write frame payload")
if !c.writeHeader.masked {
return c.bw.Write(p)
}
maskKey := c.writeHeader.maskKey
for len(p) > 0 {
// If the buffer is full, we need to flush.
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := c.bw.Buffered()
j := len(p)
if j > c.bw.Available() {
j = c.bw.Available()
}
_, err := c.bw.Write(p[:j])
if err != nil {
return n, err
}
maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
p = p[j:]
n += j
}
return n, nil
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
bw.WriteByte(0)
bw.Flush()
bw.Reset(w)
return writeBuf
}
func (c *Conn) writeError(code StatusCode, err error) {
c.writeClose(code, err.Error())
}
package websocket // import "github.com/coder/websocket"
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall/js"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/wsjs"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11-16 are reserved for further control frames.
)
// Conn provides a wrapper around the browser WebSocket API.
type Conn struct {
noCopy noCopy
ws wsjs.WebSocket
// read limit for a message in bytes.
msgReadLimit atomic.Int64
closeReadMu sync.Mutex
closeReadCtx context.Context
closingMu sync.Mutex
closeOnce sync.Once
closed chan struct{}
closeErrOnce sync.Once
closeErr error
closeWasClean bool
releaseOnClose func()
releaseOnError func()
releaseOnMessage func()
readSignal chan struct{}
readBufMu sync.Mutex
readBuf []wsjs.MessageEvent
}
func (c *Conn) close(err error, wasClean bool) {
c.closeOnce.Do(func() {
runtime.SetFinalizer(c, nil)
if !wasClean {
err = fmt.Errorf("unclean connection close: %w", err)
}
c.setCloseErr(err)
c.closeWasClean = wasClean
close(c.closed)
})
}
func (c *Conn) init() {
c.closed = make(chan struct{})
c.readSignal = make(chan struct{}, 1)
c.msgReadLimit.Store(32768)
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
err := CloseError{
Code: StatusCode(e.Code),
Reason: e.Reason,
}
// We do not know if we sent or received this close as
// its possible the browser triggered it without us
// explicitly sending it.
c.close(err, e.WasClean)
c.releaseOnClose()
c.releaseOnError()
c.releaseOnMessage()
})
c.releaseOnError = c.ws.OnError(func(v js.Value) {
c.setCloseErr(errors.New(v.Get("message").String()))
c.closeWithInternal()
})
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
c.readBuf = append(c.readBuf, e)
// Lets the read goroutine know there is definitely something in readBuf.
select {
case c.readSignal <- struct{}{}:
default:
}
})
runtime.SetFinalizer(c, func(c *Conn) {
c.setCloseErr(errors.New("connection garbage collected"))
c.closeWithInternal()
})
}
func (c *Conn) closeWithInternal() {
c.Close(StatusInternalError, "something went wrong")
}
// Read attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
c.closeReadMu.Lock()
closedRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closedRead {
return 0, nil, errors.New("WebSocket connection read closed")
}
typ, p, err := c.read(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to read: %w", err)
}
readLimit := c.msgReadLimit.Load()
if readLimit >= 0 && int64(len(p)) > readLimit {
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
c.Close(StatusMessageTooBig, err.Error())
return 0, nil, err
}
return typ, p, nil
}
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "read timed out")
return 0, nil, ctx.Err()
case <-c.readSignal:
case <-c.closed:
return 0, nil, net.ErrClosed
}
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
me := c.readBuf[0]
// We copy the messages forward and decrease the size
// of the slice to avoid reallocating.
copy(c.readBuf, c.readBuf[1:])
c.readBuf = c.readBuf[:len(c.readBuf)-1]
if len(c.readBuf) > 0 {
// Next time we read, we'll grab the message.
select {
case c.readSignal <- struct{}{}:
default:
}
}
switch p := me.Data.(type) {
case string:
return MessageText, []byte(p), nil
case []byte:
return MessageBinary, p, nil
default:
panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
}
}
// Ping is mocked out for Wasm.
func (c *Conn) Ping(ctx context.Context) error {
return nil
}
// Write writes a message of the given type to the connection.
// Always non blocking.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
err := c.write(ctx, typ, p)
if err != nil {
// Have to ensure the WebSocket is closed after a write error
// to match the Go API. It can only error if the message type
// is unexpected or the passed bytes contain invalid UTF-8 for
// MessageText.
err := fmt.Errorf("failed to write: %w", err)
c.setCloseErr(err)
c.closeWithInternal()
return err
}
return nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
if c.isClosed() {
return net.ErrClosed
}
switch typ {
case MessageBinary:
return c.ws.SendBytes(p)
case MessageText:
return c.ws.SendText(string(p))
default:
return fmt.Errorf("unexpected message type: %v", typ)
}
}
// Close closes the WebSocket with the given code and reason.
// It will wait until the peer responds with a close frame
// or the connection is closed.
// It thus performs the full WebSocket close handshake.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
if err != nil {
return fmt.Errorf("failed to close WebSocket: %w", err)
}
return nil
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
//
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
// a WebSocket without the close handshake.
func (c *Conn) CloseNow() error {
return c.Close(StatusGoingAway, "")
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
c.closingMu.Lock()
defer c.closingMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
ce := fmt.Errorf("sent close: %w", CloseError{
Code: code,
Reason: reason,
})
c.setCloseErr(ce)
err := c.ws.Close(int(code), reason)
if err != nil {
return err
}
<-c.closed
if !c.closeWasClean {
return c.closeErr
}
return nil
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.ws.Subprotocol()
}
// DialOptions represents the options available to pass to Dial.
type DialOptions struct {
// Subprotocols lists the subprotocols to negotiate with the server.
Subprotocols []string
}
// Dial creates a new WebSocket connection to the given url with the given options.
// The passed context bounds the maximum time spent waiting for the connection to open.
// The returned *http.Response is always nil or a mock. It's only in the signature
// to match the core API.
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
c, resp, err := dial(ctx, url, opts)
if err != nil {
return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
}
return c, resp, nil
}
func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
if opts == nil {
opts = &DialOptions{}
}
url = strings.Replace(url, "http://", "ws://", 1)
url = strings.Replace(url, "https://", "wss://", 1)
ws, err := wsjs.New(url, opts.Subprotocols)
if err != nil {
return nil, nil, err
}
c := &Conn{
ws: ws,
}
c.init()
opench := make(chan struct{})
releaseOpen := ws.OnOpen(func(e js.Value) {
close(opench)
})
defer releaseOpen()
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "dial timed out")
return nil, nil, ctx.Err()
case <-opench:
return c, &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}, nil
case <-c.closed:
return nil, nil, net.ErrClosed
}
}
// Reader attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
typ, p, err := c.Read(ctx)
if err != nil {
return 0, nil, err
}
return typ, bytes.NewReader(p), nil
}
// Writer returns a writer to write a WebSocket data message to the connection.
// It buffers the entire message in memory and then sends it when the writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
return &writer{
c: c,
ctx: ctx,
typ: typ,
b: bpool.Get(),
}, nil
}
type writer struct {
closed bool
c *Conn
ctx context.Context
typ MessageType
b *bytes.Buffer
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, errors.New("cannot write to closed writer")
}
n, err := w.b.Write(p)
if err != nil {
return n, fmt.Errorf("failed to write message: %w", err)
}
return n, nil
}
func (w *writer) Close() error {
if w.closed {
return errors.New("cannot close closed writer")
}
w.closed = true
defer bpool.Put(w.b)
err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
if err != nil {
return fmt.Errorf("failed to close writer: %w", err)
}
return nil
}
// CloseRead implements *Conn.CloseRead for wasm.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadMu.Unlock()
go func() {
defer cancel()
defer c.CloseNow()
_, _, err := c.read(ctx)
if err != nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit implements *Conn.SetReadLimit for wasm.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit.Store(n)
}
func (c *Conn) setCloseErr(err error) {
c.closeErrOnce.Do(func() {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
})
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
Subprotocols []string
InsecureSkipVerify bool
OriginPatterns []string
CompressionMode CompressionMode
CompressionThreshold int
}
// Accept is stubbed out for Wasm.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, errors.New("unimplemented")
}
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int
const (
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
// for every message. This applies to both server and client side.
//
// This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used.
//
// The message will only be compressed if greater than 512 bytes.
CompressionNoContextTakeover CompressionMode = iota
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
// This enables reusing the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover.
//
// If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC.
CompressionContextTakeover
// CompressionDisabled disables the deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
CompressionDisabled
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
package websocket_test
import (
"context"
"net/http"
"os"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
)
func TestWasm(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
assert.Equal(t, "subprotocol", "echo", c.Subprotocol())
assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode)
c.SetReadLimit(65536)
for i := 0; i < 10; i++ {
err = wstest.Echo(ctx, c, 65536)
assert.Success(t, err)
}
err = c.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
}
func TestWasmDialTimeout(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
beforeDial := time.Now()
_, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
assert.Error(t, err)
if time.Since(beforeDial) >= time.Second {
t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial))
}
}
// Package wsjson provides websocket helpers for JSON messages.
package wsjson
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "github.com/coder/websocket/wsjson"
import (
"context"
"encoding/json"
"fmt"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Read reads a json message from c into v.
// It will reuse buffers to avoid allocations.
// Read reads a JSON message from c into v.
// It will reuse buffers in between calls to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := read(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to read json: %w", err)
}
return nil
return read(ctx, c, v)
}
func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
typ, r, err := c.Reader(ctx)
func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to read JSON message")
_, r, err := c.Reader(ctx)
if err != nil {
return err
}
if typ != websocket.MessageText {
c.Close(websocket.StatusUnsupportedData, "can only accept text messages")
return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ)
}
b := bpool.Get()
defer func() {
bpool.Put(b)
}()
defer bpool.Put(b)
_, err = b.ReadFrom(r)
if err != nil {
......@@ -45,39 +37,32 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err = json.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
return xerrors.Errorf("failed to unmarshal json: %w", err)
return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
return nil
}
// Write writes the json message v to c.
// It will reuse buffers to avoid allocations.
// Write writes the JSON message v to c.
// It will reuse buffers in between calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := write(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to write json: %w", err)
}
return nil
return write(ctx, c, v)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) error {
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
return err
}
// We use Encode because it automatically enables buffer reuse without us
// needing to do anything. Though see https://github.com/golang/go/issues/27735
e := json.NewEncoder(w)
err = e.Encode(v)
if err != nil {
return xerrors.Errorf("failed to encode json: %w", err)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to write JSON message")
err = w.Close()
// json.Marshal cannot reuse buffers between calls as it has to return
// a copy of the byte slice but Encoder does as it directly writes to w.
err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) {
err := c.Write(ctx, websocket.MessageText, p)
if err != nil {
return 0, err
}
return len(p), nil
})).Encode(v)
if err != nil {
return err
return fmt.Errorf("failed to marshal JSON: %w", err)
}
return nil
}
package wsjson_test
import (
"encoding/json"
"io"
"strconv"
"testing"
"github.com/coder/websocket/internal/test/xrand"
)
func BenchmarkJSON(b *testing.B) {
sizes := []int{
8,
16,
32,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
}
b.Run("json.Encoder", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.NewEncoder(io.Discard).Encode(msg)
}
})
}
})
b.Run("json.Marshal", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.Marshal(msg)
}
})
}
})
}
// Package wspb provides websocket helpers for protobuf messages.
package wspb
import (
"bytes"
"context"
"sync"
"github.com/golang/protobuf/proto"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
)
// Read reads a protobuf message from c into v.
// It will reuse buffers to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := read(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to read protobuf: %w", err)
}
return nil
}
func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
if typ != websocket.MessageBinary {
c.Close(websocket.StatusUnsupportedData, "can only accept binary messages")
return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ)
}
b := bpool.Get()
defer func() {
bpool.Put(b)
}()
_, err = b.ReadFrom(r)
if err != nil {
return err
}
err = proto.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf")
return xerrors.Errorf("failed to unmarshal protobuf: %w", err)
}
return nil
}
// Write writes the protobuf message v to c.
// It will reuse buffers to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := write(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to write protobuf: %w", err)
}
return nil
}
var writeBufPool sync.Pool
func write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
b := bpool.Get()
pb := proto.NewBuffer(b.Bytes())
defer func() {
bpool.Put(bytes.NewBuffer(pb.Bytes()))
}()
err := pb.Marshal(v)
if err != nil {
return xerrors.Errorf("failed to marshal protobuf: %w", err)
}
return c.Write(ctx, websocket.MessageBinary, pb.Bytes())
}
package websocket
import (
"encoding/binary"
)
// xor applies the WebSocket masking algorithm to p
// with the given key where the first 3 bits of pos
// are the starting position in the key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the position of the next byte
// to be used for masking in the key. This is so that
// unmasking can be performed without the entire frame.
func fastXOR(key [4]byte, keyPos int, b []byte) int {
// If the payload is greater than or equal to 16 bytes, then it's worth
// masking 8 bytes at a time.
// Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859
if len(b) >= 16 {
// We first create a key that is 8 bytes long
// and is aligned on the position correctly.
var alignedKey [8]byte
for i := range alignedKey {
alignedKey[i] = key[(i+keyPos)&3]
}
k := binary.LittleEndian.Uint64(alignedKey[:])
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
v = binary.LittleEndian.Uint64(b[32:])
binary.LittleEndian.PutUint64(b[32:], v^k)
v = binary.LittleEndian.Uint64(b[40:])
binary.LittleEndian.PutUint64(b[40:], v^k)
v = binary.LittleEndian.Uint64(b[48:])
binary.LittleEndian.PutUint64(b[48:], v^k)
v = binary.LittleEndian.Uint64(b[56:])
binary.LittleEndian.PutUint64(b[56:], v^k)
v = binary.LittleEndian.Uint64(b[64:])
binary.LittleEndian.PutUint64(b[64:], v^k)
v = binary.LittleEndian.Uint64(b[72:])
binary.LittleEndian.PutUint64(b[72:], v^k)
v = binary.LittleEndian.Uint64(b[80:])
binary.LittleEndian.PutUint64(b[80:], v^k)
v = binary.LittleEndian.Uint64(b[88:])
binary.LittleEndian.PutUint64(b[88:], v^k)
v = binary.LittleEndian.Uint64(b[96:])
binary.LittleEndian.PutUint64(b[96:], v^k)
v = binary.LittleEndian.Uint64(b[104:])
binary.LittleEndian.PutUint64(b[104:], v^k)
v = binary.LittleEndian.Uint64(b[112:])
binary.LittleEndian.PutUint64(b[112:], v^k)
v = binary.LittleEndian.Uint64(b[120:])
binary.LittleEndian.PutUint64(b[120:], v^k)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
v = binary.LittleEndian.Uint64(b[32:])
binary.LittleEndian.PutUint64(b[32:], v^k)
v = binary.LittleEndian.Uint64(b[40:])
binary.LittleEndian.PutUint64(b[40:], v^k)
v = binary.LittleEndian.Uint64(b[48:])
binary.LittleEndian.PutUint64(b[48:], v^k)
v = binary.LittleEndian.Uint64(b[56:])
binary.LittleEndian.PutUint64(b[56:], v^k)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
b = b[8:]
}
}
// xor remaining bytes.
for i := range b {
b[i] ^= key[keyPos&3]
keyPos++
}
return keyPos & 3
}
package websocket
import (
"crypto/rand"
"strconv"
"testing"
"github.com/google/go-cmp/cmp"
)
func Test_xor(t *testing.T) {
t.Parallel()
key := [4]byte{0xa, 0xb, 0xc, 0xff}
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
pos := 0
pos = fastXOR(key, pos, p)
if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) {
t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p))
}
if exp := 1; !cmp.Equal(exp, pos) {
t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos))
}
}
func basixXOR(maskKey [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= maskKey[pos&3]
pos++
}
return pos & 3
}
func BenchmarkXOR(b *testing.B) {
sizes := []int{
2,
16,
32,
512,
4096,
16384,
}
fns := []struct {
name string
fn func([4]byte, int, []byte) int
}{
{
"basic",
basixXOR,
},
{
"fast",
fastXOR,
},
}
var maskKey [4]byte
_, err := rand.Read(maskKey[:])
if err != nil {
b.Fatalf("failed to populate mask key: %v", err)
}
for _, size := range sizes {
data := make([]byte, size)
b.Run(strconv.Itoa(size), func(b *testing.B) {
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
fn.fn(maskKey, 0, data)
}
})
}
})
}
}