diff --git a/conn.go b/conn.go index df3a4044e37799441ed218d94ce31312a8238b79..90a5a6a1595c112514835d26149ab20bd9768220 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ import ( "bufio" "context" "crypto/rand" + "encoding/binary" "errors" "fmt" "io" @@ -81,7 +82,7 @@ type Conn struct { readerMsgCtx context.Context readerMsgHeader header readerFrameEOF bool - readerMaskPos int + readerMaskKey uint32 setReadTimeout chan context.Context setWriteTimeout chan context.Context @@ -324,7 +325,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { } if h.masked { - fastXOR(h.maskKey, 0, b) + fastXOR(h.maskKey, b) } switch h.opcode { @@ -445,8 +446,8 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e c.readerMsgCtx = ctx c.readerMsgHeader = h + c.readerMaskKey = h.maskKey c.readerFrameEOF = false - c.readerMaskPos = 0 c.readMsgLeft = c.msgReadLimit.Load() r := &messageReader{ @@ -532,7 +533,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) { r.c.readerMsgHeader = h r.c.readerFrameEOF = false - r.c.readerMaskPos = 0 + r.c.readerMaskKey = h.maskKey } h := r.c.readerMsgHeader @@ -545,7 +546,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) { h.payloadLength -= int64(n) r.c.readMsgLeft -= int64(n) if h.masked { - r.c.readerMaskPos = fastXOR(h.maskKey, r.c.readerMaskPos, p) + r.c.readerMaskKey = fastXOR(r.c.readerMaskKey, p) } r.c.readerMsgHeader = h @@ -761,7 +762,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte c.writeHeader.payloadLength = int64(len(p)) if c.client { - _, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:]) + err = binary.Read(rand.Reader, binary.BigEndian, &c.writeHeader.maskKey) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } @@ -809,7 +810,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e } if c.client { - var keypos int + maskKey := h.maskKey for len(p) > 0 { if c.bw.Available() == 0 { err = c.bw.Flush() @@ -831,7 +832,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e return n, err } - keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2]) + maskKey = fastXOR(maskKey, c.writeBuf[i:i+n2]) p = p[n2:] n += n2 diff --git a/conn_export_test.go b/conn_export_test.go index 94195a9c86f2e9df8cec08eb3ce0b2154dd98622..9335381c725e238486b38e2779ef2833a06a0b66 100644 --- a/conn_export_test.go +++ b/conn_export_test.go @@ -37,7 +37,7 @@ func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) { return 0, nil, err } if h.masked { - fastXOR(h.maskKey, 0, b) + fastXOR(h.maskKey, b) } return OpCode(h.opcode), b, nil } diff --git a/frame.go b/frame.go index be23330e532c0cc9fede81521b266a7c49868a49..5345d5168e0e6ca9853164e44a25ccbc88060d3c 100644 --- a/frame.go +++ b/frame.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math" + "math/bits" ) //go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go @@ -69,7 +70,7 @@ type header struct { payloadLength int64 masked bool - maskKey [4]byte + maskKey uint32 } func makeWriteHeaderBuf() []byte { @@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte { if h.masked { b[1] |= 1 << 7 b = b[:len(b)+4] - copy(b[len(b)-4:], h.maskKey[:]) + binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) } return b @@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) { } if h.masked { - copy(h.maskKey[:], b) + h.maskKey = binary.LittleEndian.Uint32(b) } return h, nil @@ -321,26 +322,18 @@ func (ce CloseError) bytes() ([]byte, error) { return buf, nil } -// xor applies the WebSocket masking algorithm to p -// with the given key where the first 3 bits of pos -// are the starting position in the key. +// fastXOR applies the WebSocket masking algorithm to p +// with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // -// The returned value is the position of the next byte -// to be used for masking in the key. This is so that -// unmasking can be performed without the entire frame. -func fastXOR(key [4]byte, keyPos int, b []byte) int { - // If the payload is greater than or equal to 16 bytes, then it's worth - // masking 8 bytes at a time. - // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 - if len(b) >= 16 { - // We first create a key that is 8 bytes long - // and is aligned on the position correctly. - var alignedKey [8]byte - for i := range alignedKey { - alignedKey[i] = key[(i+keyPos)&3] - } - k := binary.LittleEndian.Uint64(alignedKey[:]) +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +func fastXOR(key uint32, b []byte) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) // At some point in the future we can clean these unrolled loops up. // See https://github.com/golang/go/issues/31586#issuecomment-487436401 @@ -348,95 +341,103 @@ func fastXOR(key [4]byte, keyPos int, b []byte) int { // Then we xor until b is less than 128 bytes. for len(b) >= 128 { v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) + binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) + binary.LittleEndian.PutUint64(b[8:], v^key64) v = binary.LittleEndian.Uint64(b[16:]) - binary.LittleEndian.PutUint64(b[16:], v^k) + binary.LittleEndian.PutUint64(b[16:], v^key64) v = binary.LittleEndian.Uint64(b[24:]) - binary.LittleEndian.PutUint64(b[24:], v^k) + binary.LittleEndian.PutUint64(b[24:], v^key64) v = binary.LittleEndian.Uint64(b[32:]) - binary.LittleEndian.PutUint64(b[32:], v^k) + binary.LittleEndian.PutUint64(b[32:], v^key64) v = binary.LittleEndian.Uint64(b[40:]) - binary.LittleEndian.PutUint64(b[40:], v^k) + binary.LittleEndian.PutUint64(b[40:], v^key64) v = binary.LittleEndian.Uint64(b[48:]) - binary.LittleEndian.PutUint64(b[48:], v^k) + binary.LittleEndian.PutUint64(b[48:], v^key64) v = binary.LittleEndian.Uint64(b[56:]) - binary.LittleEndian.PutUint64(b[56:], v^k) + binary.LittleEndian.PutUint64(b[56:], v^key64) v = binary.LittleEndian.Uint64(b[64:]) - binary.LittleEndian.PutUint64(b[64:], v^k) + binary.LittleEndian.PutUint64(b[64:], v^key64) v = binary.LittleEndian.Uint64(b[72:]) - binary.LittleEndian.PutUint64(b[72:], v^k) + binary.LittleEndian.PutUint64(b[72:], v^key64) v = binary.LittleEndian.Uint64(b[80:]) - binary.LittleEndian.PutUint64(b[80:], v^k) + binary.LittleEndian.PutUint64(b[80:], v^key64) v = binary.LittleEndian.Uint64(b[88:]) - binary.LittleEndian.PutUint64(b[88:], v^k) + binary.LittleEndian.PutUint64(b[88:], v^key64) v = binary.LittleEndian.Uint64(b[96:]) - binary.LittleEndian.PutUint64(b[96:], v^k) + binary.LittleEndian.PutUint64(b[96:], v^key64) v = binary.LittleEndian.Uint64(b[104:]) - binary.LittleEndian.PutUint64(b[104:], v^k) + binary.LittleEndian.PutUint64(b[104:], v^key64) v = binary.LittleEndian.Uint64(b[112:]) - binary.LittleEndian.PutUint64(b[112:], v^k) + binary.LittleEndian.PutUint64(b[112:], v^key64) v = binary.LittleEndian.Uint64(b[120:]) - binary.LittleEndian.PutUint64(b[120:], v^k) + binary.LittleEndian.PutUint64(b[120:], v^key64) b = b[128:] } // Then we xor until b is less than 64 bytes. for len(b) >= 64 { v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) + binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) + binary.LittleEndian.PutUint64(b[8:], v^key64) v = binary.LittleEndian.Uint64(b[16:]) - binary.LittleEndian.PutUint64(b[16:], v^k) + binary.LittleEndian.PutUint64(b[16:], v^key64) v = binary.LittleEndian.Uint64(b[24:]) - binary.LittleEndian.PutUint64(b[24:], v^k) + binary.LittleEndian.PutUint64(b[24:], v^key64) v = binary.LittleEndian.Uint64(b[32:]) - binary.LittleEndian.PutUint64(b[32:], v^k) + binary.LittleEndian.PutUint64(b[32:], v^key64) v = binary.LittleEndian.Uint64(b[40:]) - binary.LittleEndian.PutUint64(b[40:], v^k) + binary.LittleEndian.PutUint64(b[40:], v^key64) v = binary.LittleEndian.Uint64(b[48:]) - binary.LittleEndian.PutUint64(b[48:], v^k) + binary.LittleEndian.PutUint64(b[48:], v^key64) v = binary.LittleEndian.Uint64(b[56:]) - binary.LittleEndian.PutUint64(b[56:], v^k) + binary.LittleEndian.PutUint64(b[56:], v^key64) b = b[64:] } // Then we xor until b is less than 32 bytes. for len(b) >= 32 { v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) + binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) + binary.LittleEndian.PutUint64(b[8:], v^key64) v = binary.LittleEndian.Uint64(b[16:]) - binary.LittleEndian.PutUint64(b[16:], v^k) + binary.LittleEndian.PutUint64(b[16:], v^key64) v = binary.LittleEndian.Uint64(b[24:]) - binary.LittleEndian.PutUint64(b[24:], v^k) + binary.LittleEndian.PutUint64(b[24:], v^key64) b = b[32:] } // Then we xor until b is less than 16 bytes. for len(b) >= 16 { v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) + binary.LittleEndian.PutUint64(b, v^key64) v = binary.LittleEndian.Uint64(b[8:]) - binary.LittleEndian.PutUint64(b[8:], v^k) + binary.LittleEndian.PutUint64(b[8:], v^key64) b = b[16:] } // Then we xor until b is less than 8 bytes. for len(b) >= 8 { v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^k) + binary.LittleEndian.PutUint64(b, v^key64) b = b[8:] } } + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + // xor remaining bytes. for i := range b { - b[i] ^= key[keyPos&3] - keyPos++ + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) } - return keyPos & 3 + + return key } diff --git a/frame_test.go b/frame_test.go index 84742ff01ee07d0f50698fc9dc4002c6babac740..c8f4cd8d59f95c036aa46b97192133e57674f21d 100644 --- a/frame_test.go +++ b/frame_test.go @@ -4,8 +4,10 @@ package websocket import ( "bytes" + "encoding/binary" "io" "math" + "math/bits" "math/rand" "strconv" "strings" @@ -133,7 +135,7 @@ func TestHeader(t *testing.T) { } if h.masked { - rand.Read(h.maskKey[:]) + h.maskKey = rand.Uint32() } testHeader(t, h) @@ -309,17 +311,17 @@ func Test_validWireCloseCode(t *testing.T) { func Test_xor(t *testing.T) { t.Parallel() - key := [4]byte{0xa, 0xb, 0xc, 0xff} + key := []byte{0xa, 0xb, 0xc, 0xff} + key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - pos := 0 - pos = fastXOR(key, pos, p) + gotKey32 := fastXOR(key32, p) if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) } - if exp := 1; !cmp.Equal(exp, pos) { - t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos)) + if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { + t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) } } @@ -347,26 +349,37 @@ func BenchmarkXOR(b *testing.B) { fns := []struct { name string - fn func([4]byte, int, []byte) int + fn func(b *testing.B, key [4]byte, p []byte) }{ { - "basic", - basixXOR, + name: "basic", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + basixXOR(key, 0, p) + } + }, }, { - "fast", - fastXOR, + name: "fast", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.BigEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + fastXOR(key32, p) + } + }, }, } - var maskKey [4]byte - _, err := rand.Read(maskKey[:]) + var key [4]byte + _, err := rand.Read(key[:]) if err != nil { b.Fatalf("failed to populate mask key: %v", err) } for _, size := range sizes { - data := make([]byte, size) + p := make([]byte, size) b.Run(strconv.Itoa(size), func(b *testing.B) { for _, fn := range fns { @@ -374,9 +387,7 @@ func BenchmarkXOR(b *testing.B) { b.ReportAllocs() b.SetBytes(int64(size)) - for i := 0; i < b.N; i++ { - fn.fn(maskKey, 0, data) - } + fn.fn(b, key, p) }) } })