diff --git a/mask.go b/mask.go deleted file mode 100644 index 6c67f1c0a4c384f14c0d7746be82c8b8a35523fe..0000000000000000000000000000000000000000 --- a/mask.go +++ /dev/null @@ -1,17 +0,0 @@ -package websocket - -// mask applies the WebSocket masking algorithm to p -// with the given key where the first 3 bits of pos -// are the starting position in the key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the position of the next byte -// to be used for masking in the key. This is so that -// unmasking can be performed without the entire frame. -func mask(key [4]byte, pos int, p []byte) int { - for i := range p { - p[i] ^= key[pos&3] - pos++ - } - return pos & 3 -} diff --git a/websocket.go b/websocket.go index 287bf3e2862aa9c36771c7f3025bc2a1f50d8965..2f324d3acb3cbe8cc67528981729aa6379050524 100644 --- a/websocket.go +++ b/websocket.go @@ -231,7 +231,7 @@ func (c *Conn) handleControl(h header) { } if h.masked { - mask(h.maskKey, 0, b) + xor(h.maskKey, 0, b) } switch h.opcode { @@ -325,7 +325,7 @@ func (c *Conn) dataReadLoop(h header) (err error) { left -= int64(len(b)) if h.masked { - maskPos = mask(h.maskKey, maskPos, b) + maskPos = xor(h.maskKey, maskPos, b) } // Must set this before we signal the read is done. diff --git a/xor.go b/xor.go new file mode 100644 index 0000000000000000000000000000000000000000..1422f847d737248b82d8247a932af23f708575fc --- /dev/null +++ b/xor.go @@ -0,0 +1,42 @@ +package websocket + +import ( + "encoding/binary" +) + +// xor applies the WebSocket masking algorithm to p +// with the given key where the first 3 bits of pos +// are the starting position in the key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the position of the next byte +// to be used for masking in the key. This is so that +// unmasking can be performed without the entire frame. +func xor(key [4]byte, keyPos int, b []byte) int { + // If the payload is greater than 16 bytes, then it's worth + // masking 8 bytes at a time. + // Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859 + if len(b) > 16 { + // We first create a key that is 8 bytes long + // and is aligned on the position correctly. + var alignedKey [8]byte + for i := range alignedKey { + alignedKey[i] = key[(i+keyPos)&3] + } + k := binary.LittleEndian.Uint64(alignedKey[:]) + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^k) + b = b[8:] + } + } + + // xor remaining bytes. + for i := range b { + b[i] ^= key[keyPos&3] + keyPos++ + } + return keyPos & 3 +} diff --git a/mask_test.go b/xor_test.go similarity index 87% rename from mask_test.go rename to xor_test.go index 4a7b8c73c3783bfc074630dd27a47ce64be72c17..f715eda14c3e15d24c5e2992d13e2e9c41ea91ce 100644 --- a/mask_test.go +++ b/xor_test.go @@ -6,13 +6,13 @@ import ( "github.com/google/go-cmp/cmp" ) -func Test_mask(t *testing.T) { +func Test_xor(t *testing.T) { t.Parallel() key := [4]byte{0xa, 0xb, 0xc, 0xff} p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} pos := 0 - pos = mask(key, pos, p) + pos = xor(key, pos, p) if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p))