Newer
Older
package websocket // import "nhooyr.io/websocket"
import (
"syscall/js"
"nhooyr.io/websocket/internal/wsjs"
)
// Conn provides a wrapper around the browser WebSocket API.
type Conn struct {
ws wsjs.WebSocket
readClosed int64
closeOnce sync.Once
closed chan struct{}
closeErr error
releaseOnClose func()
releaseOnMessage func()
readch chan wsjs.MessageEvent
}
func (c *Conn) close(err error) {
c.closeOnce.Do(func() {
runtime.SetFinalizer(c, nil)
c.closeErr = fmt.Errorf("websocket closed: %w", err)
close(c.closed)
})
}
func (c *Conn) init() {
c.closed = make(chan struct{})
c.readch = make(chan wsjs.MessageEvent, 1)
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
cerr := CloseError{
Code: StatusCode(e.Code),
Reason: e.Reason,
}
c.close(fmt.Errorf("received close frame: %w", cerr))
})
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
c.readch <- e
})
runtime.SetFinalizer(c, func(c *Conn) {
c.close(errors.New("connection garbage collected"))
})
}
// Read attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
if atomic.LoadInt64(&c.readClosed) == 1 {
return 0, nil, fmt.Errorf("websocket connection read closed")
}
typ, p, err := c.read(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to read: %w", err)
}
return typ, p, nil
}
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
var me wsjs.MessageEvent
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "read timed out")
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
return 0, nil, ctx.Err()
case me = <-c.readch:
case <-c.closed:
return 0, nil, c.closeErr
}
switch p := me.Data.(type) {
case string:
return MessageText, []byte(p), nil
case []byte:
return MessageBinary, p, nil
default:
panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
}
}
// Write writes a message of the given type to the connection.
// Always non blocking.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write: %w", err)
}
return nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
if c.isClosed() {
return c.closeErr
}
switch typ {
case MessageBinary:
return c.ws.SendBytes(p)
case MessageText:
return c.ws.SendText(string(p))
default:
return fmt.Errorf("unexpected message type: %v", typ)
}
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// Close closes the websocket with the given code and reason.
func (c *Conn) Close(code StatusCode, reason string) error {
if c.isClosed() {
return fmt.Errorf("already closed: %w", c.closeErr)
}
err2 := c.ws.Close(int(code), reason)
if err2 != nil {
err = err2
}
c.close(err)
if !errors.Is(c.closeErr, err) {
return fmt.Errorf("failed to close websocket: %w", err)
}
return nil
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.ws.Protocol
}
// DialOptions represents the options available to pass to Dial.
type DialOptions struct {
// Subprotocols lists the subprotocols to negotiate with the server.
Subprotocols []string
}
// Dial creates a new WebSocket connection to the given url with the given options.
// The passed context bounds the maximum time spent waiting for the connection to open.
// The returned *http.Response is always nil or the zero value. It's only in the signature
// to match the core API.
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
c, resp, err := dial(ctx, url, opts)
if err != nil {
return nil, resp, fmt.Errorf("failed to websocket dial: %w", err)
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
}
return c, resp, nil
}
func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
if opts == nil {
opts = &DialOptions{}
}
ws, err := wsjs.New(url, opts.Subprotocols)
if err != nil {
return nil, nil, err
}
c := &Conn{
ws: ws,
}
c.init()
opench := make(chan struct{})
releaseOpen := ws.OnOpen(func(e js.Value) {
close(opench)
})
defer releaseOpen()
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "dial timed out")
return nil, nil, ctx.Err()
case <-opench:
case <-c.closed:
return c, nil, c.closeErr
}
// Have to return a non nil response as the normal API does that.
return c, &http.Response{}, nil
}
func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, error) {
typ, p, err := c.c.Read(ctx)
if err != nil {
return 0, nil, err
}
return typ, bytes.NewReader(p), nil
}
// Only implemented for use by *Conn.CloseRead in netconn.go
func (c *Conn) reader(ctx context.Context) {
c.read(ctx)
}