diff --git a/accept.go b/accept.go index 2dabdae3ccf21d5bec42890161728dd9b65e8db0..3120690a54b88cdd2519e1d599f519a52fb5ee3e 100644 --- a/accept.go +++ b/accept.go @@ -29,20 +29,26 @@ func AcceptSubprotocols(protocols ...string) AcceptOption { return acceptSubprotocols(protocols) } -type acceptOrigins []string +type acceptInsecureOrigin struct{} -func (o acceptOrigins) acceptOption() {} +func (o acceptInsecureOrigin) acceptOption() {} -// AcceptOrigins lists the origins that Accept will accept. -// Accept will always accept r.Host as the origin. Use this -// option when you want to accept an origin with a different domain -// than the one the WebSocket server is running on. +// AcceptInsecureOrigin disables Accept's origin verification +// behaviour. By default Accept only allows the handshake to +// succeed if the javascript that is initiating the handshake +// is on the same domain as the server. This is to prevent CSRF +// when secure data is stored in cookies. // -// Use this option with caution to avoid exposing your WebSocket -// server to a CSRF attack. // See https://stackoverflow.com/a/37837709/4283659 -func AcceptOrigins(origins ...string) AcceptOption { - return acceptOrigins(origins) +// +// Use this if you want a WebSocket server any javascript can +// connect to or you want to perform Origin verification yourself +// and allow some whitelist of domains. +// +// Ensure you understand exactly what the above means before you use +// this option in conjugation with cookies containing secure data. +func AcceptInsecureOrigin() AcceptOption { + return acceptInsecureOrigin{} } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { @@ -86,11 +92,11 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept uses w to write the handshake response so the timeouts on the http.Server apply. func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) { var subprotocols []string - origins := []string{r.Host} + verifyOrigin := true for _, opt := range opts { switch opt := opt.(type) { - case acceptOrigins: - origins = []string(opt) + case acceptInsecureOrigin: + verifyOrigin = false case acceptSubprotocols: subprotocols = []string(opt) } @@ -101,12 +107,12 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn return nil, err } - origins = append(origins, r.Host) - - err = authenticateOrigin(r, origins) - if err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return nil, err + if verifyOrigin { + err = authenticateOrigin(r) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } } hj, ok := w.(http.Hijacker) @@ -172,7 +178,7 @@ func handleKey(w http.ResponseWriter, r *http.Request) { w.Header().Set("Sec-WebSocket-Accept", responseKey) } -func authenticateOrigin(r *http.Request, origins []string) error { +func authenticateOrigin(r *http.Request) error { origin := r.Header.Get("Origin") if origin == "" { return nil @@ -181,10 +187,8 @@ func authenticateOrigin(r *http.Request, origins []string) error { if err != nil { return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err) } - for _, o := range origins { - if strings.EqualFold(u.Host, o) { - return nil - } + if strings.EqualFold(u.Host, r.Host) { + return nil } - return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin")) + return xerrors.Errorf("request origin %q is not authorized", origin) } diff --git a/accept_test.go b/accept_test.go index 4b5214dde7a45045e130a2c472545ba8390536f4..6f5c3fb9e9f3896330ab692d2bee4808a1f5a4b2 100644 --- a/accept_test.go +++ b/accept_test.go @@ -140,37 +140,39 @@ func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { - name string - origin string - authorizedOrigins []string - success bool + name string + origin string + host string + success bool }{ { name: "none", success: true, + host: "example.com", }, { name: "invalid", origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}", + host: "example.com", success: false, }, { - name: "unauthorized", - origin: "https://example.com", - authorizedOrigins: []string{"example1.com"}, - success: false, + name: "unauthorized", + origin: "https://example.com", + host: "example1.com", + success: false, }, { - name: "authorized", - origin: "https://example.com", - authorizedOrigins: []string{"example.com"}, - success: true, + name: "authorized", + origin: "https://example.com", + host: "example.com", + success: true, }, { - name: "authorizedCaseInsensitive", - origin: "https://examplE.com", - authorizedOrigins: []string{"example.com"}, - success: true, + name: "authorizedCaseInsensitive", + origin: "https://examplE.com", + host: "example.com", + success: true, }, } @@ -179,10 +181,10 @@ func Test_authenticateOrigin(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) - err := authenticateOrigin(r, tc.authorizedOrigins) + err := authenticateOrigin(r) if (err == nil) != tc.success { t.Fatalf("unexpected error value: %+v", err) } diff --git a/datatype.go b/datatype.go deleted file mode 100644 index a1d8d5751c0a4d0b911547b36955bda169ce8008..0000000000000000000000000000000000000000 --- a/datatype.go +++ /dev/null @@ -1,12 +0,0 @@ -package websocket - -// DataType represents the Opcode of a WebSocket data frame. -type DataType int - -//go:generate go run golang.org/x/tools/cmd/stringer -type=DataType - -// DataType constants. -const ( - DataText DataType = DataType(opText) - DataBinary DataType = DataType(opBinary) -) diff --git a/datatype_string.go b/datatype_string.go deleted file mode 100644 index 1b4aaba5fdc911c3d17956a0983bdfbe645478dc..0000000000000000000000000000000000000000 --- a/datatype_string.go +++ /dev/null @@ -1,25 +0,0 @@ -// Code generated by "stringer -type=DataType"; DO NOT EDIT. - -package websocket - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[DataText-1] - _ = x[DataBinary-2] -} - -const _DataType_name = "DataTextDataBinary" - -var _DataType_index = [...]uint8{0, 8, 18} - -func (i DataType) String() string { - i -= 1 - if i < 0 || i >= DataType(len(_DataType_index)-1) { - return "DataType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _DataType_name[_DataType_index[i]:_DataType_index[i+1]] -} diff --git a/example_test.go b/example_test.go index bee8e9277bfd0e2bf0b9028e37ba14165a6c01f3..702239b2afa5530610dbe85651f737fc50dc83e9 100644 --- a/example_test.go +++ b/example_test.go @@ -14,13 +14,18 @@ import ( func ExampleAccept_echo() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("echo")) if err != nil { log.Printf("server handshake failed: %v", err) return } defer c.Close(websocket.StatusInternalError, "") + if c.Subprotocol() == "" { + c.Close(websocket.StatusPolicyViolation, "cannot communicate with the default protocol") + return + } + echo := func() error { ctx, cancel := context.WithTimeout(r.Context(), time.Minute) defer cancel() diff --git a/json.go b/json.go index 53869b5917b2605df436bc4de36f8412200bd46e..24e6f3184c4aaeb795dc7305e9856249ae373ad2 100644 --- a/json.go +++ b/json.go @@ -28,7 +28,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error { return err } - if typ != DataText { + if typ != MessageText { return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ) } @@ -39,6 +39,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error { if err != nil { return xerrors.Errorf("failed to decode json: %w", err) } + return nil } @@ -52,7 +53,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w := jc.Conn.Write(ctx, DataText) + w := jc.Conn.Write(ctx, MessageText) e := json.NewEncoder(w) err := e.Encode(v) diff --git a/messagetype.go b/messagetype.go new file mode 100644 index 0000000000000000000000000000000000000000..54276b3b33d13166f7d6bf01d6347d2a0d39e863 --- /dev/null +++ b/messagetype.go @@ -0,0 +1,12 @@ +package websocket + +// MessageType represents the Opcode of a WebSocket data frame. +type MessageType int + +//go:generate go run golang.org/x/tools/cmd/stringer -type=MessageType + +// MessageType constants. +const ( + MessageText MessageType = MessageType(opText) + MessageBinary MessageType = MessageType(opBinary) +) diff --git a/messagetype_string.go b/messagetype_string.go new file mode 100644 index 0000000000000000000000000000000000000000..bc62db93b22341aa36a0eb73b51cdb0bcf36678f --- /dev/null +++ b/messagetype_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type=MessageType"; DO NOT EDIT. + +package websocket + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[MessageText-1] + _ = x[MessageBinary-2] +} + +const _MessageType_name = "MessageTextMessageBinary" + +var _MessageType_index = [...]uint8{0, 11, 24} + +func (i MessageType) String() string { + i -= 1 + if i < 0 || i >= MessageType(len(_MessageType_index)-1) { + return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] +} diff --git a/websocket.go b/websocket.go index 717cc75550066688278e5cdaa05711bb029e44e3..52b5d8dba2ce40146c91bfaca524a5e181d58160 100644 --- a/websocket.go +++ b/websocket.go @@ -23,11 +23,9 @@ type control struct { type Conn struct { subprotocol string br *bufio.Reader - // TODO switch to []byte for write buffering because for messages larger than buffers, there will always be 3 writes. One for the frame, one for the message, one for the fin. - // Also will help for compression. - bw *bufio.Writer - closer io.Closer - client bool + bw *bufio.Writer + closer io.Closer + client bool closeOnce sync.Once closeErr error @@ -36,7 +34,7 @@ type Conn struct { // Writers should send on write to begin sending // a message and then follow that up with some data // on writeBytes. - write chan DataType + write chan MessageType control chan control writeBytes chan []byte writeDone chan struct{} @@ -45,9 +43,10 @@ type Conn struct { // Then send a byte slice to readBytes to read into it. // The n of bytes read will be sent on readDone once the read into a slice is complete. // readDone will receive 0 when EOF is reached. - read chan opcode - readBytes chan []byte - readDone chan int + read chan opcode + readBytes chan []byte + readDone chan int + readerDone chan struct{} } func (c *Conn) close(err error) { @@ -77,12 +76,13 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) - c.write = make(chan DataType) + c.write = make(chan MessageType) c.control = make(chan control) c.writeDone = make(chan struct{}) c.read = make(chan opcode) c.readDone = make(chan int) c.readBytes = make(chan []byte) + c.readerDone = make(chan struct{}) runtime.SetFinalizer(c, func(c *Conn) { c.Close(StatusInternalError, "websocket: connection ended up being garbage collected") @@ -120,7 +120,7 @@ messageLoop: for { c.writeBytes = make(chan []byte) - var dataType DataType + var dataType MessageType select { case <-c.closed: return @@ -321,7 +321,7 @@ func (c *Conn) readLoop() { select { case <-c.closed: return - case c.readDone <- 0: + case c.readerDone <- struct{}{}: } } } @@ -407,7 +407,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // a WebSocket data frame of type dataType to the connection. // Ensure you close the messageWriter once you have written to entire message. // Concurrent calls to messageWriter are ok. -func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser { +func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser { + // TODO acquire write here, move state into Conn and make messageWriter allocation free. return &messageWriter{ c: c, ctx: ctx, @@ -418,7 +419,7 @@ func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser { // messageWriter enables writing to a WebSocket connection. // Ensure you close the messageWriter once you have written to entire message. type messageWriter struct { - datatype DataType + datatype MessageType ctx context.Context c *Conn acquiredLock bool @@ -489,12 +490,16 @@ func (w *messageWriter) Close() error { // Please use SetContext on the reader to bound the read operation. // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames. -func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) { +func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) { + // TODO error if the reader is not done select { + case <-c.readerDone: + // The previous reader just hit a io.EOF, we handle it for users + return c.Read(ctx) case <-c.closed: return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr) case opcode := <-c.read: - return DataType(opcode), &messageReader{ + return MessageType(opcode), messageReader{ ctx: ctx, c: c, }, nil @@ -510,13 +515,26 @@ type messageReader struct { } // Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (n int, err error) { +func (r messageReader) Read(p []byte) (int, error) { + n, err := r.read(p) + if err != nil { + // Have to return io.EOF directly for now. + if err == io.EOF { + return 0, io.EOF + } + return n, xerrors.Errorf("failed to read: %w", err) + } + return n, nil +} + +func (r messageReader) read(p []byte) (int, error) { select { case <-r.c.closed: return 0, r.c.closeErr - case <-r.c.readDone: + case <-r.c.readerDone: return 0, io.EOF case r.c.readBytes <- p: + // TODO this is potentially racey as if we return if the context is cancelled, or the conn is closed we don't know if the p is ok to use. we must close the connection and also ensure the readLoop is done before returning, likewise with writes. select { case <-r.c.closed: return 0, r.c.closeErr diff --git a/websocket_test.go b/websocket_test.go index 2133482fb44403f733dca3834acb5d2071a78c9d..868b69a37ccb7e0700c89627eeb1403a02d6da4f 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -143,9 +143,30 @@ func TestHandshake(t *testing.T) { }, }, { - name: "authorizedOrigin", + name: "acceptSecureOrigin", server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOrigins("har.bar.com", "example.com")) + c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin()) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + h := http.Header{} + h.Set("Origin", "https://127.0.0.1") + c, _, err := websocket.Dial(ctx, u, websocket.DialHeader(h)) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + { + name: "acceptInsecureOrigin", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptInsecureOrigin()) if err != nil { return err }