diff --git a/accept.go b/accept.go index 11611d81564cbc2dea90ae3b811b5025feb8ee47..e68a049b32987c4e7eaebedc4fd4a4a48cad1a95 100644 --- a/accept.go +++ b/accept.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -41,6 +43,12 @@ type AcceptOptions struct { } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { + if !r.ProtoAtLeast(1, 1) { + err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/accept_test.go b/accept_test.go index 6602a8d0e0a428e2e9ee4e8f123d2a98276a03bb..44a956a85c4bf7f16a7c6abdfa029de04810626a 100644 --- a/accept_test.go +++ b/accept_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -45,6 +47,7 @@ func Test_verifyClientHandshake(t *testing.T) { testCases := []struct { name string method string + http1 bool h map[string]string success bool }{ @@ -86,6 +89,16 @@ func Test_verifyClientHandshake(t *testing.T) { "Sec-WebSocket-Key": "", }, }, + { + name: "badHTTPVersion", + h: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": "meow123", + }, + http1: true, + }, { name: "success", h: map[string]string{ @@ -106,6 +119,12 @@ func Test_verifyClientHandshake(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(tc.method, "/", nil) + r.ProtoMajor = 1 + r.ProtoMinor = 1 + if tc.http1 { + r.ProtoMinor = 0 + } + for k, v := range tc.h { r.Header.Set(k, v) } diff --git a/ci/wasm.sh b/ci/wasm.sh index 943d380626f6c0ec7d9969aab085236091a93556..9894fca69f29ba5f9239a592cdcb2e91d90d2a64 100755 --- a/ci/wasm.sh +++ b/ci/wasm.sh @@ -6,6 +6,5 @@ cd "$(git rev-parse --show-toplevel)" GOOS=js GOARCH=wasm go vet ./... go install golang.org/x/lint/golint -# Get passing later. -#GOOS=js GOARCH=wasm golint -set_exit_status ./... -GOOS=js GOARCH=wasm go test ./internal/wsjs +GOOS=js GOARCH=wasm golint -set_exit_status ./... +GOOS=js GOARCH=wasm go test ./... diff --git a/dial.go b/dial.go index 51d2af807b5775d7bfbdf32b461eb60dfdc49b9f..79232aac86c0bd4dcbeea2137a01898eaee3bd3e 100644 --- a/dial.go +++ b/dial.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -149,6 +151,10 @@ func verifyServerResponse(r *http.Request, resp *http.Response) error { ) } + if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { + return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + } + return nil } diff --git a/dial_test.go b/dial_test.go index 96537bdbbd824b57e3a4f864bd9ed2c83514787b..083b9bf3ed6bea5984a751f1e2a5fac7954d757a 100644 --- a/dial_test.go +++ b/dial_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -97,6 +99,16 @@ func Test_verifyServerHandshake(t *testing.T) { }, success: false, }, + { + name: "badSecWebSocketProtocol", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Protocol", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, { name: "success", response: func(w http.ResponseWriter) { diff --git a/doc.go b/doc.go index 189952571860dae98e415388c2fb08ea729210ec..cb33c5c9fe36450dc77412e2584cebc18da766fb 100644 --- a/doc.go +++ b/doc.go @@ -1,3 +1,5 @@ +// +build !js + // Package websocket is a minimal and idiomatic implementation of the WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 diff --git a/example_echo_test.go b/example_echo_test.go index aad326756e03a362d4001b41e266a31e6723c447..b1afe8b3552e14b23421b72c89d51ba621fa2b94 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket_test import ( diff --git a/example_test.go b/example_test.go index 36cab2bd6b3698c9bd97f1781e6eb37347437899..2cedddf384e3b37931670ecf5e79cde470471a31 100644 --- a/example_test.go +++ b/example_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket_test import ( diff --git a/export_test.go b/export_test.go index 5a0d1c32482aee5d69f9c6362c46ce3a1bcb1340..32340b56d7ad385e7998652b3a26e4a04b277451 100644 --- a/export_test.go +++ b/export_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/go.mod b/go.mod index 34a7f872d72e550fd0f6945250462f979bf55100..6b3f28ad23d4fe36ec3bcd42c19131e2c1138617 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( golang.org/x/sys v0.0.0-20190919044723-0c1ff786ef13 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gotest.tools/gotestsum v0.3.5 mvdan.cc/sh v2.6.4+incompatible diff --git a/go.sum b/go.sum index 97d6a8358df28bc7ea0ca370efcdd26e3cca2489..de366e52be1a76453921b44b62b4d1e5826f6571 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,7 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72 h1:bw9doJza/SFBEweII/rHQh338oozWyiFsBRHtrflcws= golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/airbrake/gobrake.v2 v2.0.9 h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= diff --git a/header.go b/header.go index 6eb8610f5a65c067a581afab3214dd3ee79a57e4..613b1d1510ffca71cdaf800a0d2cbbbec20c6833 100644 --- a/header.go +++ b/header.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/header_test.go b/header_test.go index 45d0535ae52d1b5aaa8909b736a687ae3632e6c4..5d0fd6a264fcdbacd9f6ef6cd2ce719929f9fb44 100644 --- a/header_test.go +++ b/header_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/internal/echoserver/echoserver.go b/internal/echoserver/echoserver.go new file mode 100644 index 0000000000000000000000000000000000000000..905ede2b5f8d73622b7ad8df4ee05b79b5aa9d35 --- /dev/null +++ b/internal/echoserver/echoserver.go @@ -0,0 +1,11 @@ +package echoserver + +import ( + "net/http" +) + +// EchoServer provides a streaming WebSocket echo server +// for use in tests. +func EchoServer(w http.ResponseWriter, r *http.Request) { + +} diff --git a/internal/wsjs/wsjs.go b/internal/wsjs/wsjs.go index 4adb71ad2072a3d0e27616d4d8906d99d4ee97ff..f83b766cae3e33a8105d7b5359f77291b6df55a6 100644 --- a/internal/wsjs/wsjs.go +++ b/internal/wsjs/wsjs.go @@ -1,11 +1,11 @@ // +build js // Package wsjs implements typed access to the browser javascript WebSocket API. +// // https://developer.mozilla.org/en-US/docs/Web/API/WebSocket package wsjs import ( - "context" "syscall/js" ) @@ -26,9 +26,10 @@ func handleJSError(err *error, onErr func()) { } } -func New(ctx context.Context, url string, protocols []string) (c *WebSocket, err error) { +// New is a wrapper around the javascript WebSocket constructor. +func New(url string, protocols []string) (c WebSocket, err error) { defer handleJSError(&err, func() { - c = nil + c = WebSocket{} }) jsProtocols := make([]interface{}, len(protocols)) @@ -36,7 +37,7 @@ func New(ctx context.Context, url string, protocols []string) (c *WebSocket, err jsProtocols[i] = p } - c = &WebSocket{ + c = WebSocket{ v: js.Global().Get("WebSocket").New(url, jsProtocols), } @@ -49,6 +50,7 @@ func New(ctx context.Context, url string, protocols []string) (c *WebSocket, err return c, nil } +// WebSocket is a wrapper around a javascript WebSocket object. type WebSocket struct { Extensions string Protocol string @@ -57,29 +59,33 @@ type WebSocket struct { v js.Value } -func (c *WebSocket) setBinaryType(typ string) { +func (c WebSocket) setBinaryType(typ string) { c.v.Set("binaryType", string(typ)) } -func (c *WebSocket) BufferedAmount() uint32 { - return uint32(c.v.Get("bufferedAmount").Int()) -} - -func (c *WebSocket) addEventListener(eventType string, fn func(e js.Value)) { - c.v.Call("addEventListener", eventType, js.FuncOf(func(this js.Value, args []js.Value) interface{} { +func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() { + f := js.FuncOf(func(this js.Value, args []js.Value) interface{} { fn(args[0]) return nil - })) + }) + c.v.Call("addEventListener", eventType, f) + + return func() { + c.v.Call("removeEventListener", eventType, f) + f.Release() + } } +// CloseEvent is the type passed to a WebSocket close handler. type CloseEvent struct { Code uint16 Reason string WasClean bool } -func (c *WebSocket) OnClose(fn func(CloseEvent)) { - c.addEventListener("close", func(e js.Value) { +// OnClose registers a function to be called when the WebSocket is closed. +func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) { + return c.addEventListener("close", func(e js.Value) { ce := CloseEvent{ Code: uint16(e.Get("code").Int()), Reason: e.Get("reason").String(), @@ -89,23 +95,29 @@ func (c *WebSocket) OnClose(fn func(CloseEvent)) { }) } -func (c *WebSocket) OnError(fn func(e js.Value)) { - c.addEventListener("error", fn) +// OnError registers a function to be called when there is an error +// with the WebSocket. +func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) { + return c.addEventListener("error", fn) } +// MessageEvent is the type passed to a message handler. type MessageEvent struct { - Data []byte - // There are more types to the interface but we don't use them. + // string or []byte. + Data interface{} + + // There are more fields to the interface but we don't use them. // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent } -func (c *WebSocket) OnMessage(fn func(m MessageEvent)) { - c.addEventListener("message", func(e js.Value) { - var data []byte +// OnMessage registers a function to be called when the websocket receives a message. +func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { + return c.addEventListener("message", func(e js.Value) { + var data interface{} arrayBuffer := e.Get("data") if arrayBuffer.Type() == js.TypeString { - data = []byte(arrayBuffer.String()) + data = arrayBuffer.String() } else { data = extractArrayBuffer(arrayBuffer) } @@ -119,23 +131,29 @@ func (c *WebSocket) OnMessage(fn func(m MessageEvent)) { }) } -func (c *WebSocket) OnOpen(fn func(e js.Value)) { - c.addEventListener("open", fn) +// OnOpen registers a function to be called when the websocket is opened. +func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { + return c.addEventListener("open", fn) } -func (c *WebSocket) Close(code int, reason string) (err error) { +// Close closes the WebSocket with the given code and reason. +func (c WebSocket) Close(code int, reason string) (err error) { defer handleJSError(&err, nil) c.v.Call("close", code, reason) return err } -func (c *WebSocket) SendText(v string) (err error) { +// SendText sends the given string as a text message +// on the WebSocket. +func (c WebSocket) SendText(v string) (err error) { defer handleJSError(&err, nil) c.v.Call("send", v) return err } -func (c *WebSocket) SendBytes(v []byte) (err error) { +// SendBytes sends the given message as a binary message +// on the WebSocket. +func (c WebSocket) SendBytes(v []byte) (err error) { defer handleJSError(&err, nil) c.v.Call("send", uint8Array(v)) return err diff --git a/internal/wsjs/wsjs_test.go b/internal/wsjs/wsjs_test.go deleted file mode 100644 index 4f5f18789845d93665d16a72a8a87d8b4f8a3e02..0000000000000000000000000000000000000000 --- a/internal/wsjs/wsjs_test.go +++ /dev/null @@ -1,26 +0,0 @@ -// +build js - -package wsjs - -import ( - "context" - "syscall/js" - "testing" - "time" -) - -func TestWebSocket(t *testing.T) { - t.Parallel() - - c, err := New(context.Background(), "ws://localhost:8081", nil) - if err != nil { - t.Fatal(err) - } - - c.OnError(func(e js.Value) { - t.Log(js.Global().Get("JSON").Call("stringify", e)) - t.Log(c.v.Get("readyState")) - }) - - time.Sleep(time.Second) -} diff --git a/netconn.go b/netconn.go index 20b99c2a69d531245263235fe36db69f26659a0a..8efdade22d9b172c29dfe6053cb7c7e292b5a32d 100644 --- a/netconn.go +++ b/netconn.go @@ -93,7 +93,7 @@ func (c *netConn) Read(p []byte) (int, error) { } if c.reader == nil { - typ, r, err := c.c.Reader(c.readContext) + typ, r, err := c.netConnReader(c.readContext) if err != nil { var ce CloseError if errors.As(err, &ce) && (ce.Code == StatusNormalClosure) || (ce.Code == StatusGoingAway) { diff --git a/netconn_js.go b/netconn_js.go new file mode 100644 index 0000000000000000000000000000000000000000..5cd15d476283106dda69c9a1b3c5d35cfd59998d --- /dev/null +++ b/netconn_js.go @@ -0,0 +1,17 @@ +// +build js + +package websocket + +import ( + "bytes" + "context" + "io" +) + +func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, error) { + typ, p, err := c.c.Read(ctx) + if err != nil { + return 0, nil, err + } + return typ, bytes.NewReader(p), nil +} diff --git a/netconn_normal.go b/netconn_normal.go new file mode 100644 index 0000000000000000000000000000000000000000..0db551d4533da705addb384107f631906a6b31fa --- /dev/null +++ b/netconn_normal.go @@ -0,0 +1,12 @@ +// +build !js + +package websocket + +import ( + "context" + "io" +) + +func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, error) { + return c.c.Reader(c.readContext) +} diff --git a/opcode.go b/opcode.go index 86f94bd999ce8f9463ad8c291ce91503433f8369..df708aa0baa8e8c4630bab0201bed823ab2dd470 100644 --- a/opcode.go +++ b/opcode.go @@ -3,7 +3,7 @@ package websocket // opcode represents a WebSocket Opcode. type opcode int -//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode +//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode -tags js // opcode constants. const ( diff --git a/opcode_string.go b/opcode_string.go index 740b5e709a6b6ea12b1fba25ba855bc970e0c14d..d7b88961e4765a28c102310ad9ea564f9c042dcc 100644 --- a/opcode_string.go +++ b/opcode_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=opcode"; DO NOT EDIT. +// Code generated by "stringer -type=opcode -tags js"; DO NOT EDIT. package websocket diff --git a/websocket.go b/websocket.go index 9976d0fafebf80beebbb17db2516685181ec9acb..596d89f351fa2a54a6f1358d2671d73bf14cff28 100644 --- a/websocket.go +++ b/websocket.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -438,8 +440,8 @@ func (r *messageReader) eof() bool { func (r *messageReader) Read(p []byte) (int, error) { n, err := r.read(p) if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as xerrors - // isn't used in stdlib. + // Have to return io.EOF directly for now, we cannot wrap as errors.Is + // isn't used widely yet. if errors.Is(err, io.EOF) { return n, io.EOF } diff --git a/websocket_autobahn_python_test.go b/websocket_autobahn_python_test.go index a1e5cccbd6af600652a8c03836de458cabcbfce7..4e8b588e6a567e3f31c8b1270accc3c30b6f342e 100644 --- a/websocket_autobahn_python_test.go +++ b/websocket_autobahn_python_test.go @@ -1,6 +1,7 @@ // This file contains the old autobahn test suite tests that use the -// python binary. The approach is very clunky and slow so new tests +// python binary. The approach is clunky and slow so new tests // have been written in pure Go in websocket_test.go. +// These have been kept for correctness purposes and are occasionally ran. // +build autobahn-python package websocket_test diff --git a/websocket_bench_test.go b/websocket_bench_test.go index 6a54fab21c0dbf3106db21a48e13a3d51d5ba61a..9598e87339af996106355cbeb34dd6bc78eadb66 100644 --- a/websocket_bench_test.go +++ b/websocket_bench_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket_test import ( diff --git a/websocket_js.go b/websocket_js.go new file mode 100644 index 0000000000000000000000000000000000000000..aab104945d2b4d4e9de7bd0bc5f75a0a4b45abe5 --- /dev/null +++ b/websocket_js.go @@ -0,0 +1,211 @@ +package websocket // import "nhooyr.io/websocket" + +import ( + "context" + "errors" + "fmt" + "net/http" + "reflect" + "runtime" + "sync" + "syscall/js" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/wsjs" +) + +// Conn provides a wrapper around the browser WebSocket API. +type Conn struct { + ws wsjs.WebSocket + + closeOnce sync.Once + closed chan struct{} + closeErr error + + releaseOnClose func() + releaseOnMessage func() + + readch chan wsjs.MessageEvent +} + +func (c *Conn) close(err error) { + c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) + + c.closeErr = fmt.Errorf("websocket closed: %w", err) + close(c.closed) + + c.releaseOnClose() + c.releaseOnMessage() + }) +} + +func (c *Conn) init() { + c.closed = make(chan struct{}) + c.readch = make(chan wsjs.MessageEvent, 1) + + c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { + cerr := CloseError{ + Code: StatusCode(e.Code), + Reason: e.Reason, + } + + c.close(fmt.Errorf("received close frame: %w", cerr)) + }) + + c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { + c.readch <- e + }) + + runtime.SetFinalizer(c, func(c *Conn) { + c.ws.Close(int(StatusInternalError), "internal error") + c.close(errors.New("connection garbage collected")) + }) +} + +// Read attempts to read a message from the connection. +// The maximum time spent waiting is bounded by the context. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, p, err := c.read(ctx) + if err != nil { + return 0, nil, fmt.Errorf("failed to read: %w", err) + } + return typ, p, nil +} + +func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { + var me wsjs.MessageEvent + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case me = <-c.readch: + case <-c.closed: + return 0, nil, c.closeErr + } + + switch p := me.Data.(type) { + case string: + return MessageText, []byte(p), nil + case []byte: + return MessageBinary, p, nil + default: + panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String()) + } +} + +// Write writes a message of the given type to the connection. +// Always non blocking. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + err := c.write(ctx, typ, p) + if err != nil { + return fmt.Errorf("failed to write: %w", err) + } + return nil +} + +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { + if c.isClosed() { + return c.closeErr + } + switch typ { + case MessageBinary: + return c.ws.SendBytes(p) + case MessageText: + return c.ws.SendText(string(p)) + default: + return fmt.Errorf("unexpected message type: %v", typ) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +// Close closes the websocket with the given code and reason. +func (c *Conn) Close(code StatusCode, reason string) error { + if c.isClosed() { + return fmt.Errorf("already closed: %w", c.closeErr) + } + + cerr := CloseError{ + Code: code, + Reason: reason, + } + + err := fmt.Errorf("sent close frame: %v", cerr) + + err2 := c.ws.Close(int(code), reason) + if err2 != nil { + err = err2 + } + c.close(err) + + if !xerrors.Is(c.closeErr, cerr) { + return xerrors.Errorf("failed to close websocket: %w", err) + } + + return nil +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.ws.Protocol +} + +// DialOptions represents the options available to pass to Dial. +type DialOptions struct { + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string +} + +// Dial creates a new WebSocket connection to the given url with the given options. +// The passed context bounds the maximum time spent waiting for the connection to open. +// The returned *http.Response is always nil or the zero value. It's only in the signature +// to match the core API. +func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + c, resp, err := dial(ctx, url, opts) + if err != nil { + return nil, resp, fmt.Errorf("failed to dial: %w", err) + } + return c, resp, nil +} + +func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + if opts == nil { + opts = &DialOptions{} + } + + ws, err := wsjs.New(url, opts.Subprotocols) + if err != nil { + return nil, nil, err + } + + c := &Conn{ + ws: ws, + } + c.init() + + opench := make(chan struct{}) + releaseOpen := ws.OnOpen(func(e js.Value) { + close(opench) + }) + defer releaseOpen() + + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + case <-opench: + case <-c.closed: + return c, nil, c.closeErr + } + + // Have to return a non nil response as the normal API does that. + return c, &http.Response{}, nil +} diff --git a/websocket_js_test.go b/websocket_js_test.go new file mode 100644 index 0000000000000000000000000000000000000000..332c962815ff5c23850a4df4c5631a362e53ea7d --- /dev/null +++ b/websocket_js_test.go @@ -0,0 +1,20 @@ +package websocket_test + +import ( + "context" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestWebSocket(t *testing.T) { + t.Parallel() + + _, _, err := websocket.Dial(context.Background(), "ws://localhost:8081", nil) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second) +} diff --git a/websocket_test.go b/websocket_test.go index 1aa8b201a5a6a9d2f3dabb23207533cafdd9a1d7..eedef845d3ca9b12e065522fd0f5b11865df07d4 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket_test import ( diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 1e63f940ea7948c7267b36495e20c37e4e0f8588..ffdd24ac0c60d395122e8eb871c0875d00de2719 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,3 +1,5 @@ +// +build !js + // Package wsjson provides websocket helpers for JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" diff --git a/wsjson/wsjson_js.go b/wsjson/wsjson_js.go new file mode 100644 index 0000000000000000000000000000000000000000..2e6074ad5270935773d2f0b6a83d04595dbf40cd --- /dev/null +++ b/wsjson/wsjson_js.go @@ -0,0 +1,58 @@ +// +build js + +package wsjson + +import ( + "context" + "encoding/json" + "fmt" + + "nhooyr.io/websocket" +) + +// Read reads a json message from c into v. +func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { + err := read(ctx, c, v) + if err != nil { + return fmt.Errorf("failed to read json: %w", err) + } + return nil +} + +func read(ctx context.Context, c *websocket.Conn, v interface{}) error { + typ, b, err := c.Read(ctx) + if err != nil { + return err + } + + if typ != websocket.MessageText { + c.Close(websocket.StatusUnsupportedData, "can only accept text messages") + return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) + } + + err = json.Unmarshal(b, v) + if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") + return fmt.Errorf("failed to unmarshal json: %w", err) + } + + return nil +} + +// Write writes the json message v to c. +func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { + err := write(ctx, c, v) + if err != nil { + return fmt.Errorf("failed to write json: %w", err) + } + return nil +} + +func write(ctx context.Context, c *websocket.Conn, v interface{}) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + return c.Write(ctx, websocket.MessageBinary, b) +} diff --git a/wspb/wspb.go b/wspb/wspb.go index 8613a08093bdc2a68195f5cea55a1b84284a6a61..b32b0c1ba0bb06a9eb5270b79e50f4a679a53b4c 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -1,3 +1,5 @@ +// +build !js + // Package wspb provides websocket helpers for protobuf messages. package wspb // import "nhooyr.io/websocket/wspb" @@ -5,7 +7,6 @@ import ( "bytes" "context" "fmt" - "sync" "github.com/golang/protobuf/proto" @@ -63,8 +64,6 @@ func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return nil } -var writeBufPool sync.Pool - func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) diff --git a/wspb/wspb_js.go b/wspb/wspb_js.go new file mode 100644 index 0000000000000000000000000000000000000000..6f69eddd0b3947db09efd881c6a131e3f84d554b --- /dev/null +++ b/wspb/wspb_js.go @@ -0,0 +1,67 @@ +// +build js + +package wspb // import "nhooyr.io/websocket/wspb" + +import ( + "bytes" + "context" + "fmt" + + "github.com/golang/protobuf/proto" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/bpool" +) + +// Read reads a protobuf message from c into v. +func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { + err := read(ctx, c, v) + if err != nil { + return fmt.Errorf("failed to read protobuf: %w", err) + } + return nil +} + +func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { + typ, p, err := c.Read(ctx) + if err != nil { + return err + } + + if typ != websocket.MessageBinary { + c.Close(websocket.StatusUnsupportedData, "can only accept binary messages") + return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) + } + + err = proto.Unmarshal(p, v) + if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") + return fmt.Errorf("failed to unmarshal protobuf: %w", err) + } + + return nil +} + +// Write writes the protobuf message v to c. +func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { + err := write(ctx, c, v) + if err != nil { + return fmt.Errorf("failed to write protobuf: %w", err) + } + return nil +} + +func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { + b := bpool.Get() + pb := proto.NewBuffer(b.Bytes()) + defer func() { + bpool.Put(bytes.NewBuffer(pb.Bytes())) + }() + + err := pb.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal protobuf: %w", err) + } + + return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) +} diff --git a/xor.go b/xor.go index 852930df813dd1101584ecd1be90eccee9a29188..f9fe2051fceb16a186fe096b0071d1d7d1a87975 100644 --- a/xor.go +++ b/xor.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/xor_test.go b/xor_test.go index 634af606ac0184ef6ba244a7fb168f2e7403d1ca..70047a9cba2440bf6be246a3a486eb6edbf1d795 100644 --- a/xor_test.go +++ b/xor_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import (