good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
Showing
with 1211 additions and 260 deletions
//go:build !js
package websocket
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"math"
"github.com/coder/websocket/internal/errd"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11-16 are reserved for further control frames.
)
// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
type header struct {
fin bool
rsv1 bool
rsv2 bool
rsv3 bool
opcode opcode
payloadLength int64
masked bool
maskKey uint32
}
// readFrameHeader reads a header from the reader.
// See https://tools.ietf.org/html/rfc6455#section-5.2.
func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
defer errd.Wrap(&err, "failed to read frame header")
b, err := r.ReadByte()
if err != nil {
return header{}, err
}
h.fin = b&(1<<7) != 0
h.rsv1 = b&(1<<6) != 0
h.rsv2 = b&(1<<5) != 0
h.rsv3 = b&(1<<4) != 0
h.opcode = opcode(b & 0xf)
b, err = r.ReadByte()
if err != nil {
return header{}, err
}
h.masked = b&(1<<7) != 0
payloadLength := b &^ (1 << 7)
switch {
case payloadLength < 126:
h.payloadLength = int64(payloadLength)
case payloadLength == 126:
_, err = io.ReadFull(r, readBuf[:2])
h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
case payloadLength == 127:
_, err = io.ReadFull(r, readBuf)
h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
}
if err != nil {
return header{}, err
}
if h.payloadLength < 0 {
return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
}
if h.masked {
_, err = io.ReadFull(r, readBuf[:4])
if err != nil {
return header{}, err
}
h.maskKey = binary.LittleEndian.Uint32(readBuf)
}
return h, nil
}
// maxControlPayload is the maximum length of a control frame payload.
// See https://tools.ietf.org/html/rfc6455#section-5.5.
const maxControlPayload = 125
// writeFrameHeader writes the bytes of the header to w.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
defer errd.Wrap(&err, "failed to write frame header")
var b byte
if h.fin {
b |= 1 << 7
}
if h.rsv1 {
b |= 1 << 6
}
if h.rsv2 {
b |= 1 << 5
}
if h.rsv3 {
b |= 1 << 4
}
b |= byte(h.opcode)
err = w.WriteByte(b)
if err != nil {
return err
}
lengthByte := byte(0)
if h.masked {
lengthByte |= 1 << 7
}
switch {
case h.payloadLength > math.MaxUint16:
lengthByte |= 127
case h.payloadLength > 125:
lengthByte |= 126
case h.payloadLength >= 0:
lengthByte |= byte(h.payloadLength)
}
err = w.WriteByte(lengthByte)
if err != nil {
return err
}
switch {
case h.payloadLength > math.MaxUint16:
binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
_, err = w.Write(buf)
case h.payloadLength > 125:
binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
_, err = w.Write(buf[:2])
}
if err != nil {
return err
}
if h.masked {
binary.LittleEndian.PutUint32(buf, h.maskKey)
_, err = w.Write(buf[:4])
if err != nil {
return err
}
}
return nil
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"bytes"
"encoding/binary"
"math/bits"
"math/rand"
"strconv"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/coder/websocket/internal/test/assert"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func randBool() bool {
return rand.Intn(1) == 0
}
func TestHeader(t *testing.T) {
t.Parallel()
t.Run("readNegativeLength", func(t *testing.T) {
t.Parallel()
b := writeHeader(nil, header{
payloadLength: 1<<16 + 1,
})
// Make length negative
b[2] |= 1 << 7
r := bytes.NewReader(b)
_, err := readHeader(nil, r)
if err == nil {
t.Fatalf("unexpected error value: %+v", err)
}
})
t.Run("lengths", func(t *testing.T) {
t.Parallel()
......@@ -45,12 +26,12 @@ func TestHeader(t *testing.T) {
124,
125,
126,
4096,
16384,
127,
65534,
65535,
65536,
65537,
131072,
}
for _, n := range lengths {
......@@ -68,20 +49,24 @@ func TestHeader(t *testing.T) {
t.Run("fuzz", func(t *testing.T) {
t.Parallel()
r := rand.New(rand.NewSource(time.Now().UnixNano()))
randBool := func() bool {
return r.Intn(2) == 0
}
for i := 0; i < 10000; i++ {
h := header{
fin: randBool(),
rsv1: randBool(),
rsv2: randBool(),
rsv3: randBool(),
opcode: opcode(rand.Intn(1 << 4)),
opcode: opcode(r.Intn(16)),
masked: randBool(),
payloadLength: rand.Int63(),
payloadLength: r.Int63(),
}
if h.masked {
rand.Read(h.maskKey[:])
h.maskKey = r.Uint32()
}
testHeader(t, h)
......@@ -90,18 +75,33 @@ func TestHeader(t *testing.T) {
}
func testHeader(t *testing.T, h header) {
b := writeHeader(nil, h)
r := bytes.NewReader(b)
h2, err := readHeader(nil, r)
if err != nil {
t.Logf("header: %#v", h)
t.Logf("bytes: %b", b)
t.Fatalf("failed to read header: %v", err)
}
if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) {
t.Logf("header: %#v", h)
t.Logf("bytes: %b", b)
t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
}
b := &bytes.Buffer{}
w := bufio.NewWriter(b)
r := bufio.NewReader(b)
err := writeFrameHeader(h, w, make([]byte, 8))
assert.Success(t, err)
err = w.Flush()
assert.Success(t, err)
h2, err := readFrameHeader(r, make([]byte, 8))
assert.Success(t, err)
assert.Equal(t, "read header", h, h2)
}
func Test_mask(t *testing.T) {
t.Parallel()
key := []byte{0xa, 0xb, 0xc, 0xff}
key32 := binary.LittleEndian.Uint32(key)
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
gotKey32 := mask(p, key32)
expP := []byte{0, 0, 0, 0x0d, 0x6}
assert.Equal(t, "p", expP, p)
expKey32 := bits.RotateLeft32(key32, -8)
assert.Equal(t, "key32", expKey32, gotKey32)
}
module nhooyr.io/websocket
module github.com/coder/websocket
go 1.12
require (
github.com/golang/protobuf v1.3.1
github.com/google/go-cmp v0.2.0
github.com/kr/pretty v0.1.0 // indirect
go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16
golang.org/x/lint v0.0.0-20190409202823-959b441ac422
golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6
golang.org/x/text v0.3.2 // indirect
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
golang.org/x/tools v0.0.0-20190429184909-35c670923e21
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522
mvdan.cc/sh v2.6.4+incompatible
)
go 1.23
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16 h1:3gGa1bM0nG7Ruhu5b7wKnoOOwAD/fJ8iyyAcpOzDG3A=
go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16/go.mod h1:iKV5yK9t+J5nG9O3uF6KYdPEz3dyfMyB15MN1rbQ8Qw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/lint v0.0.0-20190409202823-959b441ac422 h1:QzoH/1pFpZguR8NrRHLcO6jKqfv2zpuSqZLgdm7ZmjI=
golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 h1:FP8hkuE6yUEaJnK7O2eTuejKWwW+Rhfj80dQ2JcKxCU=
golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190429184909-35c670923e21 h1:Kjcw+D2LTzLmxOHrMK9uvYP/NigJ0EdwMgzt6EU+Ghs=
golang.org/x/tools v0.0.0-20190429184909-35c670923e21/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM=
mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8=
package websocket
import (
"encoding/binary"
"fmt"
"io"
"math"
"golang.org/x/xerrors"
)
// First byte contains fin, rsv1, rsv2, rsv3.
// Second byte contains mask flag and payload len.
// Next 8 bytes are the maximum extended payload length.
// Last 4 bytes are the mask key.
// https://tools.ietf.org/html/rfc6455#section-5.2
const maxHeaderSize = 1 + 1 + 8 + 4
// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type header struct {
fin bool
rsv1 bool
rsv2 bool
rsv3 bool
opcode opcode
payloadLength int64
masked bool
maskKey [4]byte
}
func makeWriteHeaderBuf() []byte {
return make([]byte, maxHeaderSize)
}
// bytes returns the bytes of the header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func writeHeader(b []byte, h header) []byte {
if b == nil {
b = makeWriteHeaderBuf()
}
b = b[:2]
b[0] = 0
if h.fin {
b[0] |= 1 << 7
}
if h.rsv1 {
b[0] |= 1 << 6
}
if h.rsv2 {
b[0] |= 1 << 5
}
if h.rsv3 {
b[0] |= 1 << 4
}
b[0] |= byte(h.opcode)
switch {
case h.payloadLength < 0:
panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength))
case h.payloadLength <= 125:
b[1] = byte(h.payloadLength)
case h.payloadLength <= math.MaxUint16:
b[1] = 126
b = b[:len(b)+2]
binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength))
default:
b[1] = 127
b = b[:len(b)+8]
binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength))
}
if h.masked {
b[1] |= 1 << 7
b = b[:len(b)+4]
copy(b[len(b)-4:], h.maskKey[:])
}
return b
}
func makeReadHeaderBuf() []byte {
return make([]byte, maxHeaderSize-2)
}
// readHeader reads a header from the reader.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func readHeader(b []byte, r io.Reader) (header, error) {
if b == nil {
b = makeReadHeaderBuf()
}
// We read the first two bytes first so that we know
// exactly how long the header is.
b = b[:2]
_, err := io.ReadFull(r, b)
if err != nil {
return header{}, err
}
var h header
h.fin = b[0]&(1<<7) != 0
h.rsv1 = b[0]&(1<<6) != 0
h.rsv2 = b[0]&(1<<5) != 0
h.rsv3 = b[0]&(1<<4) != 0
h.opcode = opcode(b[0] & 0xf)
var extra int
h.masked = b[1]&(1<<7) != 0
if h.masked {
extra += 4
}
payloadLength := b[1] &^ (1 << 7)
switch {
case payloadLength < 126:
h.payloadLength = int64(payloadLength)
case payloadLength == 126:
extra += 2
case payloadLength == 127:
extra += 8
}
if extra == 0 {
return h, nil
}
b = b[:extra]
_, err = io.ReadFull(r, b)
if err != nil {
return header{}, err
}
switch {
case payloadLength == 126:
h.payloadLength = int64(binary.BigEndian.Uint16(b))
b = b[2:]
case payloadLength == 127:
h.payloadLength = int64(binary.BigEndian.Uint64(b))
if h.payloadLength < 0 {
return header{}, xerrors.Errorf("header with negative payload length: %v", h.payloadLength)
}
b = b[8:]
}
if h.masked {
copy(h.maskKey[:], b)
}
return h, nil
}
//go:build !js
package websocket
import (
"net/http"
)
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
// hijacker returns the Hijacker interface of the http.ResponseWriter.
// It follows the Unwrap method of the http.ResponseWriter if available,
// matching the behavior of http.ResponseController. If the Hijacker
// interface is not found, it returns false.
//
// Since the http.ResponseController is not available in Go 1.19, and
// does not support checking the presence of the Hijacker interface,
// this function is used to provide a consistent way to check for the
// Hijacker interface across Go versions.
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t, true
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil, false
}
}
}
//go:build !js && go1.20
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
_, _, err := http.NewResponseController(w).Hijack()
assert.Contains(t, err, "haha")
hj, ok := hijacker(w)
assert.Equal(t, "hijacker found", ok, true)
_, _, err = hj.Hijack()
assert.Contains(t, err, "haha")
}
......@@ -5,16 +5,17 @@ import (
"sync"
)
var bpool sync.Pool
var bpool = sync.Pool{
New: func() any {
return &bytes.Buffer{}
},
}
// Get returns a buffer from the pool or creates a new one if
// the pool is empty.
func Get() *bytes.Buffer {
b, ok := bpool.Get().(*bytes.Buffer)
if !ok {
b = &bytes.Buffer{}
}
return b
b := bpool.Get()
return b.(*bytes.Buffer)
}
// Put returns a buffer into the pool.
......
package bpool
import (
"strconv"
"sync"
"testing"
)
func BenchmarkSyncPool(b *testing.B) {
sizes := []int{
2,
16,
32,
64,
128,
256,
512,
4096,
16384,
}
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
b.Run("allocate", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf := make([]byte, size)
_ = buf
}
})
b.Run("pool", func(b *testing.B) {
b.ReportAllocs()
p := sync.Pool{}
for i := 0; i < b.N; i++ {
buf := p.Get()
if buf == nil {
buf = make([]byte, size)
}
p.Put(buf)
}
})
})
}
}
package errd
import (
"fmt"
)
// Wrap wraps err with fmt.Errorf if err is non nil.
// Intended for use with defer and a named error return.
// Inspired by https://github.com/golang/go/issues/32676.
func Wrap(err *error, f string, v ...interface{}) {
if *err != nil {
*err = fmt.Errorf(f+": %w", append(v, *err)...)
}
}
# Examples
This directory contains more involved examples unsuitable
for display with godoc.
# Chat Example
This directory contains a full stack example of a simple chat webapp using github.com/coder/websocket.
```bash
$ cd examples/chat
$ go run . localhost:0
listening on ws://127.0.0.1:51055
```
Visit the printed URL to submit and view broadcasted messages in a browser.
![Image of Example](https://i.imgur.com/VwJl9Bh.png)
## Structure
The frontend is contained in `index.html`, `index.js` and `index.css`. It sets up the
DOM with a scrollable div at the top that is populated with new messages as they are broadcast.
At the bottom it adds a form to submit messages.
The messages are received via the WebSocket `/subscribe` endpoint and published via
the HTTP POST `/publish` endpoint. The reason for not publishing messages over the WebSocket
is so that you can easily publish a message with curl.
The server portion is `main.go` and `chat.go` and implements serving the static frontend
assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoint.
The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
`index.html` and then `index.js`.
There are two automated tests for the server included in `chat_test.go`. The first is a simple one
client echo test. It publishes a single message and ensures it's received.
The second is a complex concurrency test where 10 clients send 128 unique messages
of max 128 bytes concurrently. The test ensures all messages are seen by every client.
package main
import (
"context"
"errors"
"io"
"log"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/coder/websocket"
)
// chatServer enables broadcasting to a set of subscribers.
type chatServer struct {
// subscriberMessageBuffer controls the max number
// of messages that can be queued for a subscriber
// before it is kicked.
//
// Defaults to 16.
subscriberMessageBuffer int
// publishLimiter controls the rate limit applied to the publish endpoint.
//
// Defaults to one publish every 100ms with a burst of 8.
publishLimiter *rate.Limiter
// logf controls where logs are sent.
// Defaults to log.Printf.
logf func(f string, v ...interface{})
// serveMux routes the various endpoints to the appropriate handler.
serveMux http.ServeMux
subscribersMu sync.Mutex
subscribers map[*subscriber]struct{}
}
// newChatServer constructs a chatServer with the defaults.
func newChatServer() *chatServer {
cs := &chatServer{
subscriberMessageBuffer: 16,
logf: log.Printf,
subscribers: make(map[*subscriber]struct{}),
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
}
cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
cs.serveMux.HandleFunc("/publish", cs.publishHandler)
return cs
}
// subscriber represents a subscriber.
// Messages are sent on the msgs channel and if the client
// cannot keep up with the messages, closeSlow is called.
type subscriber struct {
msgs chan []byte
closeSlow func()
}
func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cs.serveMux.ServeHTTP(w, r)
}
// subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
err := cs.subscribe(w, r)
if errors.Is(err, context.Canceled) {
return
}
if websocket.CloseStatus(err) == websocket.StatusNormalClosure ||
websocket.CloseStatus(err) == websocket.StatusGoingAway {
return
}
if err != nil {
cs.logf("%v", err)
return
}
}
// publishHandler reads the request body with a limit of 8192 bytes and then publishes
// the received message.
func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
body := http.MaxBytesReader(w, r.Body, 8192)
msg, err := io.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
return
}
cs.publish(msg)
w.WriteHeader(http.StatusAccepted)
}
// subscribe subscribes the given WebSocket to all broadcast messages.
// It creates a subscriber with a buffered msgs chan to give some room to slower
// connections and then registers the subscriber. It then listens for all messages
// and writes them to the WebSocket. If the context is cancelled or
// an error occurs, it returns and deletes the subscription.
//
// It uses CloseRead to keep reading from the connection to process control
// messages and cancel the context if the connection drops.
func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error {
var mu sync.Mutex
var c *websocket.Conn
var closed bool
s := &subscriber{
msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() {
mu.Lock()
defer mu.Unlock()
closed = true
if c != nil {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
}
},
}
cs.addSubscriber(s)
defer cs.deleteSubscriber(s)
c2, err := websocket.Accept(w, r, nil)
if err != nil {
return err
}
mu.Lock()
if closed {
mu.Unlock()
return net.ErrClosed
}
c = c2
mu.Unlock()
defer c.CloseNow()
ctx := c.CloseRead(context.Background())
for {
select {
case msg := <-s.msgs:
err := writeTimeout(ctx, time.Second*5, c, msg)
if err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
// publish publishes the msg to all subscribers.
// It never blocks and so messages to slow subscribers
// are dropped.
func (cs *chatServer) publish(msg []byte) {
cs.subscribersMu.Lock()
defer cs.subscribersMu.Unlock()
cs.publishLimiter.Wait(context.Background())
for s := range cs.subscribers {
select {
case s.msgs <- msg:
default:
go s.closeSlow()
}
}
}
// addSubscriber registers a subscriber.
func (cs *chatServer) addSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
cs.subscribers[s] = struct{}{}
cs.subscribersMu.Unlock()
}
// deleteSubscriber deletes the given subscriber.
func (cs *chatServer) deleteSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
delete(cs.subscribers, s)
cs.subscribersMu.Unlock()
}
func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return c.Write(ctx, websocket.MessageText, msg)
}
package main
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"golang.org/x/time/rate"
"github.com/coder/websocket"
)
func Test_chatServer(t *testing.T) {
t.Parallel()
// This is a simple echo test with a single client.
// The client sends a message and ensures it receives
// it on its WebSocket.
t.Run("simple", func(t *testing.T) {
t.Parallel()
url, closeFn := setupTest(t)
defer closeFn()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
cl, err := newClient(ctx, url)
assertSuccess(t, err)
defer cl.Close()
expMsg := randString(512)
err = cl.publish(ctx, expMsg)
assertSuccess(t, err)
msg, err := cl.nextMessage()
assertSuccess(t, err)
if expMsg != msg {
t.Fatalf("expected %v but got %v", expMsg, msg)
}
})
// This test is a complex concurrency test.
// 10 clients are started that send 128 different
// messages of max 128 bytes concurrently.
//
// The test verifies that every message is seen by every client
// and no errors occur anywhere.
t.Run("concurrency", func(t *testing.T) {
t.Parallel()
const nmessages = 128
const maxMessageSize = 128
const nclients = 16
url, closeFn := setupTest(t)
defer closeFn()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var clients []*client
var clientMsgs []map[string]struct{}
for i := 0; i < nclients; i++ {
cl, err := newClient(ctx, url)
assertSuccess(t, err)
defer cl.Close()
clients = append(clients, cl)
clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize))
}
allMessages := make(map[string]struct{})
for _, msgs := range clientMsgs {
for m := range msgs {
allMessages[m] = struct{}{}
}
}
var wg sync.WaitGroup
for i, cl := range clients {
i := i
cl := cl
wg.Add(1)
go func() {
defer wg.Done()
err := cl.publishMsgs(ctx, clientMsgs[i])
if err != nil {
t.Errorf("client %d failed to publish all messages: %v", i, err)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
err := testAllMessagesReceived(cl, nclients*nmessages, allMessages)
if err != nil {
t.Errorf("client %d failed to receive all messages: %v", i, err)
}
}()
}
wg.Wait()
})
}
// setupTest sets up chatServer that can be used
// via the returned url.
//
// Defer closeFn to ensure everything is cleaned up at
// the end of the test.
//
// chatServer logs will be logged via t.Logf.
func setupTest(t *testing.T) (url string, closeFn func()) {
cs := newChatServer()
cs.logf = t.Logf
// To ensure tests run quickly under even -race.
cs.subscriberMessageBuffer = 4096
cs.publishLimiter.SetLimit(rate.Inf)
s := httptest.NewServer(cs)
return s.URL, func() {
s.Close()
}
}
// testAllMessagesReceived ensures that after n reads, all msgs in msgs
// have been read.
func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error {
msgs = cloneMessages(msgs)
for i := 0; i < n; i++ {
msg, err := cl.nextMessage()
if err != nil {
return err
}
delete(msgs, msg)
}
if len(msgs) != 0 {
return fmt.Errorf("did not receive all expected messages: %q", msgs)
}
return nil
}
func cloneMessages(msgs map[string]struct{}) map[string]struct{} {
msgs2 := make(map[string]struct{}, len(msgs))
for m := range msgs {
msgs2[m] = struct{}{}
}
return msgs2
}
func randMessages(n, maxMessageLength int) map[string]struct{} {
msgs := make(map[string]struct{})
for i := 0; i < n; i++ {
m := randString(randInt(maxMessageLength))
if _, ok := msgs[m]; ok {
i--
continue
}
msgs[m] = struct{}{}
}
return msgs
}
func assertSuccess(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
type client struct {
url string
c *websocket.Conn
}
func newClient(ctx context.Context, url string) (*client, error) {
c, _, err := websocket.Dial(ctx, url+"/subscribe", nil)
if err != nil {
return nil, err
}
cl := &client{
url: url,
c: c,
}
return cl, nil
}
func (cl *client) publish(ctx context.Context, msg string) (err error) {
defer func() {
if err != nil {
cl.c.Close(websocket.StatusInternalError, "publish failed")
}
}()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
return fmt.Errorf("publish request failed: %v", resp.StatusCode)
}
return nil
}
func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error {
for m := range msgs {
err := cl.publish(ctx, m)
if err != nil {
return err
}
}
return nil
}
func (cl *client) nextMessage() (string, error) {
typ, b, err := cl.c.Read(context.Background())
if err != nil {
return "", err
}
if typ != websocket.MessageText {
cl.c.Close(websocket.StatusUnsupportedData, "expected text message")
return "", fmt.Errorf("expected text message but got %v", typ)
}
return string(b), nil
}
func (cl *client) Close() error {
return cl.c.Close(websocket.StatusNormalClosure, "")
}
// randString generates a random string with length n.
func randString(n int) string {
b := make([]byte, n)
_, err := rand.Reader.Read(b)
if err != nil {
panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
}
s := strings.ToValidUTF8(string(b), "_")
s = strings.ReplaceAll(s, "\x00", "_")
if len(s) > n {
return s[:n]
}
if len(s) < n {
// Pad with =
extra := n - len(s)
return s + strings.Repeat("=", extra)
}
return s
}
// randInt returns a randomly generated integer between [0, max).
func randInt(max int) int {
x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil {
panic(fmt.Sprintf("failed to get random int: %v", err))
}
return int(x.Int64())
}
body {
width: 100vw;
min-width: 320px;
}
#root {
padding: 40px 20px;
max-width: 600px;
margin: auto;
height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
#root > * + * {
margin: 20px 0 0 0;
}
/* 100vh on safari does not include the bottom bar. */
@supports (-webkit-overflow-scrolling: touch) {
#root {
height: 85vh;
}
}
#message-log {
width: 100%;
flex-grow: 1;
overflow: auto;
}
#message-log p:first-child {
margin: 0;
}
#message-log > * + * {
margin: 10px 0 0 0;
}
#publish-form-container {
width: 100%;
}
#publish-form {
width: 100%;
display: flex;
height: 40px;
}
#publish-form > * + * {
margin: 0 0 0 10px;
}
#publish-form input[type='text'] {
flex-grow: 1;
-moz-appearance: none;
-webkit-appearance: none;
word-break: normal;
border-radius: 5px;
border: 1px solid #ccc;
}
#publish-form input[type='submit'] {
color: white;
background-color: black;
border-radius: 5px;
padding: 5px 10px;
border: none;
}
#publish-form input[type='submit']:hover {
background-color: red;
}
#publish-form input[type='submit']:active {
background-color: red;
}
<!doctype html>
<html lang="en-CA">
<head>
<meta charset="UTF-8" />
<title>github.com/coder/websocket - Chat Example</title>
<meta name="viewport" content="width=device-width" />
<link href="https://unpkg.com/sanitize.css" rel="stylesheet" />
<link href="https://unpkg.com/sanitize.css/typography.css" rel="stylesheet" />
<link href="https://unpkg.com/sanitize.css/forms.css" rel="stylesheet" />
<link href="/index.css" rel="stylesheet" />
</head>
<body>
<div id="root">
<div id="message-log"></div>
<div id="publish-form-container">
<form id="publish-form">
<input name="message" id="message-input" type="text" />
<input value="Submit" type="submit" />
</form>
</div>
</div>
<script type="text/javascript" src="/index.js"></script>
</body>
</html>
;(() => {
// expectingMessage is set to true
// if the user has just submitted a message
// and so we should scroll the next message into view when received.
let expectingMessage = false
function dial() {
const conn = new WebSocket(`ws://${location.host}/subscribe`)
conn.addEventListener('close', ev => {
appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true)
if (ev.code !== 1001) {
appendLog('Reconnecting in 1s', true)
setTimeout(dial, 1000)
}
})
conn.addEventListener('open', ev => {
console.info('websocket connected')
})
// This is where we handle messages received.
conn.addEventListener('message', ev => {
if (typeof ev.data !== 'string') {
console.error('unexpected message type', typeof ev.data)
return
}
const p = appendLog(ev.data)
if (expectingMessage) {
p.scrollIntoView()
expectingMessage = false
}
})
}
dial()
const messageLog = document.getElementById('message-log')
const publishForm = document.getElementById('publish-form')
const messageInput = document.getElementById('message-input')
// appendLog appends the passed text to messageLog.
function appendLog(text, error) {
const p = document.createElement('p')
// Adding a timestamp to each message makes the log easier to read.
p.innerText = `${new Date().toLocaleTimeString()}: ${text}`
if (error) {
p.style.color = 'red'
p.style.fontStyle = 'bold'
}
messageLog.append(p)
return p
}
appendLog('Submit a message to get started!')
// onsubmit publishes the message from the user when the form is submitted.
publishForm.onsubmit = async ev => {
ev.preventDefault()
const msg = messageInput.value
if (msg === '') {
return
}
messageInput.value = ''
expectingMessage = true
try {
const resp = await fetch('/publish', {
method: 'POST',
body: msg,
})
if (resp.status !== 202) {
throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`)
}
} catch (err) {
appendLog(`Publish failed: ${err.message}`, true)
}
}
})()
package main
import (
"context"
"errors"
"log"
"net"
"net/http"
"os"
"os/signal"
"time"
)
func main() {
log.SetFlags(0)
err := run()
if err != nil {
log.Fatal(err)
}
}
// run initializes the chatServer and then
// starts a http.Server for the passed in address.
func run() error {
if len(os.Args) < 2 {
return errors.New("please provide an address to listen on as the first argument")
}
l, err := net.Listen("tcp", os.Args[1])
if err != nil {
return err
}
log.Printf("listening on ws://%v", l.Addr())
cs := newChatServer()
s := &http.Server{
Handler: cs,
ReadTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
}
errc := make(chan error, 1)
go func() {
errc <- s.Serve(l)
}()
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
select {
case err := <-errc:
log.Printf("failed to serve: %v", err)
case sig := <-sigs:
log.Printf("terminating: %v", sig)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.Shutdown(ctx)
}
# Echo Example
This directory contains a echo server example using github.com/coder/websocket.
```bash
$ cd examples/echo
$ go run . localhost:0
listening on ws://127.0.0.1:51055
```
You can use a WebSocket client like https://github.com/hashrocket/ws to connect. All messages
written will be echoed back.
## Structure
The server is in `server.go` and is implemented as a `http.HandlerFunc` that accepts the WebSocket
and then reads all messages and writes them exactly as is back to the connection.
`server_test.go` contains a small unit test to verify it works correctly.
`main.go` brings it all together so that you can run it and play around with it.
package main
import (
"context"
"errors"
"log"
"net"
"net/http"
"os"
"os/signal"
"time"
)
func main() {
log.SetFlags(0)
err := run()
if err != nil {
log.Fatal(err)
}
}
// run starts a http.Server for the passed in address
// with all requests handled by echoServer.
func run() error {
if len(os.Args) < 2 {
return errors.New("please provide an address to listen on as the first argument")
}
l, err := net.Listen("tcp", os.Args[1])
if err != nil {
return err
}
log.Printf("listening on ws://%v", l.Addr())
s := &http.Server{
Handler: echoServer{
logf: log.Printf,
},
ReadTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
}
errc := make(chan error, 1)
go func() {
errc <- s.Serve(l)
}()
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
select {
case err := <-errc:
log.Printf("failed to serve: %v", err)
case sig := <-sigs:
log.Printf("terminating: %v", sig)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.Shutdown(ctx)
}