good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
package websocket
import (
"encoding/binary"
"math/bits"
)
// maskGo applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func maskGo(b []byte, key uint32) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}
// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}
// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
#include "textflag.h"
// func maskAsm(b *byte, len int, key uint32)
TEXT ·maskAsm(SB), NOSPLIT, $0-28
// AX = b
// CX = len (left length)
// SI = key (uint32)
// DI = uint64(SI) | uint64(SI)<<32
MOVQ b+0(FP), AX
MOVQ len+8(FP), CX
MOVL key+16(FP), SI
// calculate the DI
// DI = SI<<32 | SI
MOVL SI, DI
MOVQ DI, DX
SHLQ $32, DI
ORQ DX, DI
CMPQ CX, $15
JLE less_than_16
CMPQ CX, $63
JLE less_than_64
CMPQ CX, $128
JLE sse
TESTQ $31, AX
JNZ unaligned
unaligned_loop_1byte:
XORB SI, (AX)
INCQ AX
DECQ CX
ROLL $24, SI
TESTQ $7, AX
JNZ unaligned_loop_1byte
// calculate DI again since SI was modified
// DI = SI<<32 | SI
MOVL SI, DI
MOVQ DI, DX
SHLQ $32, DI
ORQ DX, DI
TESTQ $31, AX
JZ sse
unaligned:
TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b.
JNZ unaligned_loop_1byte
unaligned_loop:
// we don't need to check the CX since we know it's above 128
XORQ DI, (AX)
ADDQ $8, AX
SUBQ $8, CX
TESTQ $31, AX
JNZ unaligned_loop
JMP sse
sse:
CMPQ CX, $0x40
JL less_than_64
MOVQ DI, X0
PUNPCKLQDQ X0, X0
sse_loop:
MOVOU 0*16(AX), X1
MOVOU 1*16(AX), X2
MOVOU 2*16(AX), X3
MOVOU 3*16(AX), X4
PXOR X0, X1
PXOR X0, X2
PXOR X0, X3
PXOR X0, X4
MOVOU X1, 0*16(AX)
MOVOU X2, 1*16(AX)
MOVOU X3, 2*16(AX)
MOVOU X4, 3*16(AX)
ADDQ $0x40, AX
SUBQ $0x40, CX
CMPQ CX, $0x40
JAE sse_loop
less_than_64:
TESTQ $32, CX
JZ less_than_32
XORQ DI, (AX)
XORQ DI, 8(AX)
XORQ DI, 16(AX)
XORQ DI, 24(AX)
ADDQ $32, AX
less_than_32:
TESTQ $16, CX
JZ less_than_16
XORQ DI, (AX)
XORQ DI, 8(AX)
ADDQ $16, AX
less_than_16:
TESTQ $8, CX
JZ less_than_8
XORQ DI, (AX)
ADDQ $8, AX
less_than_8:
TESTQ $4, CX
JZ less_than_4
XORL SI, (AX)
ADDQ $4, AX
less_than_4:
TESTQ $2, CX
JZ less_than_2
XORW SI, (AX)
ROLL $16, SI
ADDQ $2, AX
less_than_2:
TESTQ $1, CX
JZ done
XORB SI, (AX)
ROLL $24, SI
done:
MOVL SI, ret+24(FP)
RET
#include "textflag.h"
// func maskAsm(b *byte, len int, key uint32)
TEXT ·maskAsm(SB), NOSPLIT, $0-28
// R0 = b
// R1 = len
// R3 = key (uint32)
// R2 = uint64(key)<<32 | uint64(key)
MOVD b_ptr+0(FP), R0
MOVD b_len+8(FP), R1
MOVWU key+16(FP), R3
MOVD R3, R2
ORR R2<<32, R2, R2
VDUP R2, V0.D2
CMP $64, R1
BLT less_than_64
loop_64:
VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16]
VEOR V1.B16, V0.B16, V1.B16
VEOR V2.B16, V0.B16, V2.B16
VEOR V3.B16, V0.B16, V3.B16
VEOR V4.B16, V0.B16, V4.B16
VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0)
SUBS $64, R1
CMP $64, R1
BGE loop_64
less_than_64:
CBZ R1, end
TBZ $5, R1, less_than_32
VLD1 (R0), [V1.B16, V2.B16]
VEOR V1.B16, V0.B16, V1.B16
VEOR V2.B16, V0.B16, V2.B16
VST1.P [V1.B16, V2.B16], 32(R0)
less_than_32:
TBZ $4, R1, less_than_16
LDP (R0), (R11, R12)
EOR R11, R2, R11
EOR R12, R2, R12
STP.P (R11, R12), 16(R0)
less_than_16:
TBZ $3, R1, less_than_8
MOVD (R0), R11
EOR R2, R11, R11
MOVD.P R11, 8(R0)
less_than_8:
TBZ $2, R1, less_than_4
MOVWU (R0), R11
EORW R2, R11, R11
MOVWU.P R11, 4(R0)
less_than_4:
TBZ $1, R1, less_than_2
MOVHU (R0), R11
EORW R3, R11, R11
MOVHU.P R11, 2(R0)
RORW $16, R3
less_than_2:
TBZ $0, R1, end
MOVBU (R0), R11
EORW R3, R11, R11
MOVBU.P R11, 1(R0)
RORW $8, R3
end:
MOVWU R3, ret+24(FP)
RET
//go:build amd64 || arm64
package websocket
func mask(b []byte, key uint32) uint32 {
// TODO: Will enable in v1.9.0.
return maskGo(b, key)
/*
if len(b) > 0 {
return maskAsm(&b[0], len(b), key)
}
return key
*/
}
// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this
// function are perfect. There are almost certainly missing optimizations or
// opportunities for simplification. I'm confident there are no bugs though.
// For example, the arm64 implementation doesn't align memory like the amd64.
// Or the amd64 implementation could use AVX512 instead of just AVX2.
// The AVX2 code I had to disable anyway as it wasn't performing as expected.
// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049
//
//go:noescape
//lint:ignore U1000 disabled till v1.9.0
func maskAsm(b *byte, len int, key uint32) uint32
//go:build amd64 || arm64
package websocket
import "testing"
func TestMaskASM(t *testing.T) {
t.Parallel()
testMask(t, "maskASM", mask)
}
//go:build !amd64 && !arm64 && !js
package websocket
func mask(b []byte, key uint32) uint32 {
return maskGo(b, key)
}
package websocket
import (
"bytes"
"crypto/rand"
"encoding/binary"
"math/big"
"math/bits"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func basicMask(b []byte, key uint32) uint32 {
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
func basicMask2(b []byte, key uint32) uint32 {
keyb := binary.LittleEndian.AppendUint32(nil, key)
pos := 0
for i := range b {
b[i] ^= keyb[pos&3]
pos++
}
return bits.RotateLeft32(key, (pos&3)*-8)
}
func TestMask(t *testing.T) {
t.Parallel()
testMask(t, "basicMask", basicMask)
testMask(t, "maskGo", maskGo)
testMask(t, "basicMask2", basicMask2)
}
func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) {
t.Run(name, func(t *testing.T) {
t.Parallel()
for i := 0; i < 9999; i++ {
keyb := make([]byte, 4)
_, err := rand.Read(keyb)
assert.Success(t, err)
key := binary.LittleEndian.Uint32(keyb)
n, err := rand.Int(rand.Reader, big.NewInt(1<<16))
assert.Success(t, err)
b := make([]byte, 1+n.Int64())
_, err = rand.Read(b)
assert.Success(t, err)
b2 := make([]byte, len(b))
copy(b2, b)
b3 := make([]byte, len(b))
copy(b3, b)
key2 := basicMask(b2, key)
key3 := fn(b3, key)
if key2 != key3 {
t.Errorf("expected key %X but got %X", key2, key3)
}
if !bytes.Equal(b2, b3) {
t.Error("bad bytes")
return
}
}
})
}
......@@ -68,7 +68,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
defer nc.writeMu.unlock()
// Prevents future writes from writing until the deadline is reset.
atomic.StoreInt64(&nc.writeExpired, 1)
nc.writeExpired.Store(1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
......@@ -84,7 +84,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
defer nc.readMu.unlock()
// Prevents future reads from reading until the deadline is reset.
atomic.StoreInt64(&nc.readExpired, 1)
nc.readExpired.Store(1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
......@@ -99,13 +99,13 @@ type netConn struct {
writeTimer *time.Timer
writeMu *mu
writeExpired int64
writeExpired atomic.Int64
writeCtx context.Context
writeCancel context.CancelFunc
readTimer *time.Timer
readMu *mu
readExpired int64
readExpired atomic.Int64
readCtx context.Context
readCancel context.CancelFunc
readEOFed bool
......@@ -126,7 +126,7 @@ func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()
if atomic.LoadInt64(&nc.writeExpired) == 1 {
if nc.writeExpired.Load() == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}
......@@ -141,7 +141,20 @@ func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()
if atomic.LoadInt64(&nc.readExpired) == 1 {
for {
n, err := nc.read(p)
if err != nil {
return n, err
}
if n == 0 {
continue
}
return n, nil
}
}
func (nc *netConn) read(p []byte) (int, error) {
if nc.readExpired.Load() == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}
......@@ -193,21 +206,29 @@ func (nc *netConn) SetDeadline(t time.Time) error {
}
func (nc *netConn) SetWriteDeadline(t time.Time) error {
atomic.StoreInt64(&nc.writeExpired, 0)
nc.writeExpired.Store(0)
if t.IsZero() {
nc.writeTimer.Stop()
} else {
nc.writeTimer.Reset(time.Until(t))
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.writeTimer.Reset(dur)
}
return nil
}
func (nc *netConn) SetReadDeadline(t time.Time) error {
atomic.StoreInt64(&nc.readExpired, 0)
nc.readExpired.Store(0)
if t.IsZero() {
nc.readTimer.Stop()
} else {
nc.readTimer.Reset(time.Until(t))
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.readTimer.Reset(dur)
}
return nil
}
......@@ -9,11 +9,13 @@ import (
"errors"
"fmt"
"io"
"net"
"strings"
"sync/atomic"
"time"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/xsync"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Reader reads from the connection until there is a WebSocket
......@@ -58,12 +60,28 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
//
// This function is idempotent.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadDone = make(chan struct{})
c.closeReadMu.Unlock()
go func() {
defer close(c.closeReadDone)
defer cancel()
c.Reader(ctx)
c.Close(StatusPolicyViolation, "unexpected data message")
defer c.close()
_, _, err := c.Reader(ctx)
if err == nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
......@@ -101,13 +119,20 @@ func newMsgReader(c *Conn) *msgReader {
func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
}
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail)
}
......@@ -122,7 +147,10 @@ func (mr *msgReader) putFlateReader() {
func (mr *msgReader) close() {
mr.c.readMu.forceLock()
mr.putFlateReader()
mr.dict.close()
if mr.dict != nil {
mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil {
putBufioReader(mr.flateBufio)
}
......@@ -189,60 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
}
}
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
// prepareRead sets the readTimeout context and returns a done function
// to be called after the read is done. It also returns an error if the
// connection is closed. The reference to the error is used to assign
// an error depending on if the connection closed or the context timed
// out during use. Typically the referenced error is a named return
// variable of the function calling this method.
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
select {
case <-c.closed:
return header{}, c.closeErr
return nil, net.ErrClosed
case c.readTimeout <- ctx:
}
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
done := func() {
select {
case <-c.closed:
return header{}, c.closeErr
case <-ctx.Done():
return header{}, ctx.Err()
default:
c.close(err)
return header{}, err
if *err != nil {
*err = net.ErrClosed
}
case c.readTimeout <- context.Background():
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}
select {
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- context.Background():
c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
defer done()
return nil, closeReceivedErr
}
return h, nil
return done, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
select {
case <-c.closed:
return 0, c.closeErr
case c.readTimeout <- ctx:
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return header{}, err
}
defer readDone()
n, err := io.ReadFull(c.br, p)
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
select {
case <-c.closed:
return n, c.closeErr
case <-ctx.Done():
return n, ctx.Err()
default:
err = fmt.Errorf("failed to read frame payload: %w", err)
c.close(err)
return n, err
}
return header{}, err
}
select {
case <-c.closed:
return n, c.closeErr
case c.readTimeout <- context.Background():
return h, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return 0, err
}
defer readDone()
n, err := io.ReadFull(c.br, p)
if err != nil {
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
return n, err
......@@ -271,13 +307,21 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
}
if h.masked {
mask(h.maskKey, b)
mask(b, h.maskKey)
}
switch h.opcode {
case opPing:
if c.onPingReceived != nil {
if !c.onPingReceived(ctx, b) {
return nil
}
}
return c.writeControl(ctx, opPong, b)
case opPong:
if c.onPongReceived != nil {
c.onPongReceived(ctx, b)
}
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
......@@ -290,9 +334,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
return nil
}
defer func() {
c.readCloseFrameErr = err
}()
// opClose
ce, err := parseClosePayload(b)
if err != nil {
......@@ -302,9 +344,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
}
err = fmt.Errorf("received close frame: %w", ce)
c.setCloseErr(err)
c.writeClose(ce.Code, ce.Reason)
c.close(err)
c.closeStateMu.Lock()
c.closeReceivedErr = err
closeSent := c.closeSentErr != nil
c.closeStateMu.Unlock()
// Only unlock readMu if this connection is being closed becaue
// c.close will try to acquire the readMu lock. We unlock for
// writeClose as well because it may also call c.close.
if !closeSent {
c.readMu.unlock()
_ = c.writeClose(ce.Code, ce.Reason)
}
if !c.casClosing() {
c.readMu.unlock()
_ = c.close()
}
return err
}
......@@ -318,9 +373,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
defer c.readMu.unlock()
if !c.msgReader.fin {
err = errors.New("previous message not read to completion")
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
return 0, nil, errors.New("previous message not read to completion")
}
h, err := c.readLoop(ctx)
......@@ -348,14 +401,14 @@ type msgReader struct {
flateBufio *bufio.Reader
flateTail strings.Reader
limitReader *limitReader
dict slidingWindow
dict *slidingWindow
fin bool
payloadLength int64
maskKey uint32
// readerFunc(mr.Read) to avoid continuous allocations.
readFunc readerFunc
// util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc util.ReaderFunc
}
func (mr *msgReader) reset(ctx context.Context, h header) {
......@@ -393,10 +446,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
return n, io.EOF
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.close(err)
return n, fmt.Errorf("failed to read: %w", err)
}
return n, err
return n, nil
}
func (mr *msgReader) read(p []byte) (int, error) {
......@@ -435,7 +487,7 @@ func (mr *msgReader) read(p []byte) (int, error) {
mr.payloadLength -= int64(n)
if !mr.c.client {
mr.maskKey = mask(mr.maskKey, p)
mr.maskKey = mask(p, mr.maskKey)
}
return n, nil
......@@ -445,7 +497,7 @@ func (mr *msgReader) read(p []byte) (int, error) {
type limitReader struct {
c *Conn
r io.Reader
limit xsync.Int64
limit atomic.Int64
n int64
}
......@@ -484,9 +536,3 @@ func (lr *limitReader) Read(p []byte) (int, error) {
}
return n, err
}
type readerFunc func(p []byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
......@@ -5,17 +5,18 @@ package websocket
import (
"bufio"
"compress/flate"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"compress/flate"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Writer returns a writer bounded by the context that will write
......@@ -37,7 +38,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the threshold is not met, then it
// If compression is disabled or the compression threshold is not met, then it
// will write the message in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
......@@ -48,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
}
type msgWriter struct {
mw *msgWriterState
closed bool
}
func (mw *msgWriter) Write(p []byte) (int, error) {
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
return mw.mw.Write(p)
}
func (mw *msgWriter) Close() error {
if mw.closed {
return errors.New("cannot use closed writer")
}
mw.closed = true
return mw.mw.Close()
}
type msgWriterState struct {
c *Conn
mu *mu
writeMu *mu
closed bool
ctx context.Context
opcode opcode
......@@ -81,8 +63,8 @@ type msgWriterState struct {
flateWriter *flate.Writer
}
func newMsgWriterState(c *Conn) *msgWriterState {
mw := &msgWriterState{
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
......@@ -90,10 +72,10 @@ func newMsgWriterState(c *Conn) *msgWriterState {
return mw
}
func (mw *msgWriterState) ensureFlate() {
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
w: util.WriterFunc(mw.write),
}
}
......@@ -103,7 +85,7 @@ func (mw *msgWriterState) ensureFlate() {
mw.flate = true
}
func (mw *msgWriterState) flateContextTakeover() bool {
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
......@@ -111,14 +93,11 @@ func (mw *msgWriterState) flateContextTakeover() bool {
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriterState.reset(ctx, typ)
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return &msgWriter{
mw: c.msgWriterState,
closed: false,
}, nil
return c.msgWriter, nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
......@@ -128,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
}
if !c.flate() {
defer c.msgWriterState.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
n, err := mw.Write(p)
......@@ -141,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
return n, err
}
func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
......@@ -150,13 +129,14 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false
mw.trimWriter.reset()
return nil
}
func (mw *msgWriterState) putFlateWriter() {
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
......@@ -164,17 +144,20 @@ func (mw *msgWriterState) putFlateWriter() {
}
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
mw.c.close(err)
}
}()
......@@ -193,7 +176,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
return mw.write(p)
}
func (mw *msgWriterState) write(p []byte) (int, error) {
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
......@@ -203,7 +186,7 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
}
// Close flushes the frame to the connection.
func (mw *msgWriterState) Close() (err error) {
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
err = mw.writeMu.lock(mw.ctx)
......@@ -212,6 +195,11 @@ func (mw *msgWriterState) Close() (err error) {
}
defer mw.writeMu.unlock()
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
......@@ -231,7 +219,7 @@ func (mw *msgWriterState) Close() (err error) {
return nil
}
func (mw *msgWriterState) close() {
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
......@@ -252,48 +240,44 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
return nil
}
// frame handles all writes to the connection.
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
defer c.writeFrameMu.unlock()
// If the state says a close has already been written, we wait until
// the connection is closed and return that error.
//
// However, if the frame being written is a close, that means its the close from
// the state being set so we let it go through.
c.closeMu.Lock()
wroteClose := c.wroteClose
c.closeMu.Unlock()
if wroteClose && opcode != opClose {
c.writeFrameMu.unlock()
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-c.closed:
return 0, c.closeErr
defer func() {
if c.isClosed() && opcode == opClose {
err = nil
}
if err != nil {
if ctx.Err() != nil {
err = ctx.Err()
} else if c.isClosed() {
err = net.ErrClosed
}
err = fmt.Errorf("failed to write frame: %w", err)
}
}()
c.closeStateMu.Lock()
closeSentErr := c.closeSentErr
c.closeStateMu.Unlock()
if closeSentErr != nil {
return 0, net.ErrClosed
}
defer c.writeFrameMu.unlock()
select {
case <-c.closed:
return 0, c.closeErr
return 0, net.ErrClosed
case c.writeTimeout <- ctx:
}
defer func() {
if err != nil {
select {
case <-c.closed:
err = c.closeErr
case <-ctx.Done():
err = ctx.Err()
}
c.close(err)
err = fmt.Errorf("failed to write frame: %w", err)
select {
case <-c.closed:
case c.writeTimeout <- context.Background():
}
}()
......@@ -332,10 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
}
}
select {
case <-c.closed:
return n, c.closeErr
case c.writeTimeout <- context.Background():
if opcode == opClose {
c.closeStateMu.Lock()
c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
closeReceived := c.closeReceivedErr != nil
c.closeStateMu.Unlock()
if closeReceived && !c.casClosing() {
c.writeFrameMu.unlock()
_ = c.close()
}
}
return n, nil
......@@ -371,7 +361,7 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
return n, err
}
maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
p = p[j:]
n += j
......@@ -380,17 +370,11 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
return n, nil
}
type writerFunc func(p []byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
bw.Reset(writerFunc(func(p2 []byte) (int, error) {
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
......@@ -404,7 +388,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
}
func (c *Conn) writeError(code StatusCode, err error) {
c.setCloseErr(err)
c.writeClose(code, err.Error())
c.close(nil)
}
This diff is collapsed.
......@@ -7,9 +7,9 @@ import (
"testing"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
)
func TestWasm(t *testing.T) {
......
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "nhooyr.io/websocket/wsjson"
package wsjson // import "github.com/coder/websocket/wsjson"
import (
"context"
"encoding/json"
"fmt"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/util"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Read reads a JSON message from c into v.
......
package wsjson_test
import (
"encoding/json"
"io"
"strconv"
"testing"
"github.com/coder/websocket/internal/test/xrand"
)
func BenchmarkJSON(b *testing.B) {
sizes := []int{
8,
16,
32,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
}
b.Run("json.Encoder", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.NewEncoder(io.Discard).Encode(msg)
}
})
}
})
b.Run("json.Marshal", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.Marshal(msg)
}
})
}
})
}