diff --git a/conn.go b/conn.go
index df3a4044e37799441ed218d94ce31312a8238b79..90a5a6a1595c112514835d26149ab20bd9768220 100644
--- a/conn.go
+++ b/conn.go
@@ -6,6 +6,7 @@ import (
 	"bufio"
 	"context"
 	"crypto/rand"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -81,7 +82,7 @@ type Conn struct {
 	readerMsgCtx    context.Context
 	readerMsgHeader header
 	readerFrameEOF  bool
-	readerMaskPos   int
+	readerMaskKey   uint32
 
 	setReadTimeout  chan context.Context
 	setWriteTimeout chan context.Context
@@ -324,7 +325,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 	}
 
 	if h.masked {
-		fastXOR(h.maskKey, 0, b)
+		fastXOR(h.maskKey, b)
 	}
 
 	switch h.opcode {
@@ -445,8 +446,8 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
 
 	c.readerMsgCtx = ctx
 	c.readerMsgHeader = h
+	c.readerMaskKey = h.maskKey
 	c.readerFrameEOF = false
-	c.readerMaskPos = 0
 	c.readMsgLeft = c.msgReadLimit.Load()
 
 	r := &messageReader{
@@ -532,7 +533,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) {
 
 		r.c.readerMsgHeader = h
 		r.c.readerFrameEOF = false
-		r.c.readerMaskPos = 0
+		r.c.readerMaskKey = h.maskKey
 	}
 
 	h := r.c.readerMsgHeader
@@ -545,7 +546,7 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) {
 	h.payloadLength -= int64(n)
 	r.c.readMsgLeft -= int64(n)
 	if h.masked {
-		r.c.readerMaskPos = fastXOR(h.maskKey, r.c.readerMaskPos, p)
+		r.c.readerMaskKey = fastXOR(r.c.readerMaskKey, p)
 	}
 	r.c.readerMsgHeader = h
 
@@ -761,7 +762,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
 	c.writeHeader.payloadLength = int64(len(p))
 
 	if c.client {
-		_, err := io.ReadFull(rand.Reader, c.writeHeader.maskKey[:])
+		err = binary.Read(rand.Reader, binary.BigEndian, &c.writeHeader.maskKey)
 		if err != nil {
 			return 0, fmt.Errorf("failed to generate masking key: %w", err)
 		}
@@ -809,7 +810,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
 	}
 
 	if c.client {
-		var keypos int
+		maskKey := h.maskKey
 		for len(p) > 0 {
 			if c.bw.Available() == 0 {
 				err = c.bw.Flush()
@@ -831,7 +832,7 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e
 				return n, err
 			}
 
-			keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])
+			maskKey = fastXOR(maskKey, c.writeBuf[i:i+n2])
 
 			p = p[n2:]
 			n += n2
diff --git a/conn_export_test.go b/conn_export_test.go
index 94195a9c86f2e9df8cec08eb3ce0b2154dd98622..9335381c725e238486b38e2779ef2833a06a0b66 100644
--- a/conn_export_test.go
+++ b/conn_export_test.go
@@ -37,7 +37,7 @@ func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) {
 		return 0, nil, err
 	}
 	if h.masked {
-		fastXOR(h.maskKey, 0, b)
+		fastXOR(h.maskKey, b)
 	}
 	return OpCode(h.opcode), b, nil
 }
diff --git a/frame.go b/frame.go
index be23330e532c0cc9fede81521b266a7c49868a49..5345d5168e0e6ca9853164e44a25ccbc88060d3c 100644
--- a/frame.go
+++ b/frame.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"math"
+	"math/bits"
 )
 
 //go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go
@@ -69,7 +70,7 @@ type header struct {
 	payloadLength int64
 
 	masked  bool
-	maskKey [4]byte
+	maskKey uint32
 }
 
 func makeWriteHeaderBuf() []byte {
@@ -119,7 +120,7 @@ func writeHeader(b []byte, h header) []byte {
 	if h.masked {
 		b[1] |= 1 << 7
 		b = b[:len(b)+4]
-		copy(b[len(b)-4:], h.maskKey[:])
+		binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey)
 	}
 
 	return b
@@ -192,7 +193,7 @@ func readHeader(b []byte, r io.Reader) (header, error) {
 	}
 
 	if h.masked {
-		copy(h.maskKey[:], b)
+		h.maskKey = binary.LittleEndian.Uint32(b)
 	}
 
 	return h, nil
@@ -321,26 +322,18 @@ func (ce CloseError) bytes() ([]byte, error) {
 	return buf, nil
 }
 
-// xor applies the WebSocket masking algorithm to p
-// with the given key where the first 3 bits of pos
-// are the starting position in the key.
+// fastXOR applies the WebSocket masking algorithm to p
+// with the given key.
 // See https://tools.ietf.org/html/rfc6455#section-5.3
 //
-// The returned value is the position of the next byte
-// to be used for masking in the key. This is so that
-// unmasking can be performed without the entire frame.
-func fastXOR(key [4]byte, keyPos int, b []byte) int {
-	// If the payload is greater than or equal to 16 bytes, then it's worth
-	// masking 8 bytes at a time.
-	// Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859
-	if len(b) >= 16 {
-		// We first create a key that is 8 bytes long
-		// and is aligned on the position correctly.
-		var alignedKey [8]byte
-		for i := range alignedKey {
-			alignedKey[i] = key[(i+keyPos)&3]
-		}
-		k := binary.LittleEndian.Uint64(alignedKey[:])
+// The returned value is the correctly rotated key to
+// continue to mask/unmask the message.
+//
+// It is optimized for little-endian byte order and expects
+// the key to be in little endian.
+func fastXOR(key uint32, b []byte) uint32 {
+	if len(b) >= 8 {
+		key64 := uint64(key)<<32 | uint64(key)
 
 		// At some point in the future we can clean these unrolled loops up.
 		// See https://github.com/golang/go/issues/31586#issuecomment-487436401
@@ -348,95 +341,103 @@ func fastXOR(key [4]byte, keyPos int, b []byte) int {
 		// Then we xor until b is less than 128 bytes.
 		for len(b) >= 128 {
 			v := binary.LittleEndian.Uint64(b)
-			binary.LittleEndian.PutUint64(b, v^k)
+			binary.LittleEndian.PutUint64(b, v^key64)
 			v = binary.LittleEndian.Uint64(b[8:])
-			binary.LittleEndian.PutUint64(b[8:], v^k)
+			binary.LittleEndian.PutUint64(b[8:], v^key64)
 			v = binary.LittleEndian.Uint64(b[16:])
-			binary.LittleEndian.PutUint64(b[16:], v^k)
+			binary.LittleEndian.PutUint64(b[16:], v^key64)
 			v = binary.LittleEndian.Uint64(b[24:])
-			binary.LittleEndian.PutUint64(b[24:], v^k)
+			binary.LittleEndian.PutUint64(b[24:], v^key64)
 			v = binary.LittleEndian.Uint64(b[32:])
-			binary.LittleEndian.PutUint64(b[32:], v^k)
+			binary.LittleEndian.PutUint64(b[32:], v^key64)
 			v = binary.LittleEndian.Uint64(b[40:])
-			binary.LittleEndian.PutUint64(b[40:], v^k)
+			binary.LittleEndian.PutUint64(b[40:], v^key64)
 			v = binary.LittleEndian.Uint64(b[48:])
-			binary.LittleEndian.PutUint64(b[48:], v^k)
+			binary.LittleEndian.PutUint64(b[48:], v^key64)
 			v = binary.LittleEndian.Uint64(b[56:])
-			binary.LittleEndian.PutUint64(b[56:], v^k)
+			binary.LittleEndian.PutUint64(b[56:], v^key64)
 			v = binary.LittleEndian.Uint64(b[64:])
-			binary.LittleEndian.PutUint64(b[64:], v^k)
+			binary.LittleEndian.PutUint64(b[64:], v^key64)
 			v = binary.LittleEndian.Uint64(b[72:])
-			binary.LittleEndian.PutUint64(b[72:], v^k)
+			binary.LittleEndian.PutUint64(b[72:], v^key64)
 			v = binary.LittleEndian.Uint64(b[80:])
-			binary.LittleEndian.PutUint64(b[80:], v^k)
+			binary.LittleEndian.PutUint64(b[80:], v^key64)
 			v = binary.LittleEndian.Uint64(b[88:])
-			binary.LittleEndian.PutUint64(b[88:], v^k)
+			binary.LittleEndian.PutUint64(b[88:], v^key64)
 			v = binary.LittleEndian.Uint64(b[96:])
-			binary.LittleEndian.PutUint64(b[96:], v^k)
+			binary.LittleEndian.PutUint64(b[96:], v^key64)
 			v = binary.LittleEndian.Uint64(b[104:])
-			binary.LittleEndian.PutUint64(b[104:], v^k)
+			binary.LittleEndian.PutUint64(b[104:], v^key64)
 			v = binary.LittleEndian.Uint64(b[112:])
-			binary.LittleEndian.PutUint64(b[112:], v^k)
+			binary.LittleEndian.PutUint64(b[112:], v^key64)
 			v = binary.LittleEndian.Uint64(b[120:])
-			binary.LittleEndian.PutUint64(b[120:], v^k)
+			binary.LittleEndian.PutUint64(b[120:], v^key64)
 			b = b[128:]
 		}
 
 		// Then we xor until b is less than 64 bytes.
 		for len(b) >= 64 {
 			v := binary.LittleEndian.Uint64(b)
-			binary.LittleEndian.PutUint64(b, v^k)
+			binary.LittleEndian.PutUint64(b, v^key64)
 			v = binary.LittleEndian.Uint64(b[8:])
-			binary.LittleEndian.PutUint64(b[8:], v^k)
+			binary.LittleEndian.PutUint64(b[8:], v^key64)
 			v = binary.LittleEndian.Uint64(b[16:])
-			binary.LittleEndian.PutUint64(b[16:], v^k)
+			binary.LittleEndian.PutUint64(b[16:], v^key64)
 			v = binary.LittleEndian.Uint64(b[24:])
-			binary.LittleEndian.PutUint64(b[24:], v^k)
+			binary.LittleEndian.PutUint64(b[24:], v^key64)
 			v = binary.LittleEndian.Uint64(b[32:])
-			binary.LittleEndian.PutUint64(b[32:], v^k)
+			binary.LittleEndian.PutUint64(b[32:], v^key64)
 			v = binary.LittleEndian.Uint64(b[40:])
-			binary.LittleEndian.PutUint64(b[40:], v^k)
+			binary.LittleEndian.PutUint64(b[40:], v^key64)
 			v = binary.LittleEndian.Uint64(b[48:])
-			binary.LittleEndian.PutUint64(b[48:], v^k)
+			binary.LittleEndian.PutUint64(b[48:], v^key64)
 			v = binary.LittleEndian.Uint64(b[56:])
-			binary.LittleEndian.PutUint64(b[56:], v^k)
+			binary.LittleEndian.PutUint64(b[56:], v^key64)
 			b = b[64:]
 		}
 
 		// Then we xor until b is less than 32 bytes.
 		for len(b) >= 32 {
 			v := binary.LittleEndian.Uint64(b)
-			binary.LittleEndian.PutUint64(b, v^k)
+			binary.LittleEndian.PutUint64(b, v^key64)
 			v = binary.LittleEndian.Uint64(b[8:])
-			binary.LittleEndian.PutUint64(b[8:], v^k)
+			binary.LittleEndian.PutUint64(b[8:], v^key64)
 			v = binary.LittleEndian.Uint64(b[16:])
-			binary.LittleEndian.PutUint64(b[16:], v^k)
+			binary.LittleEndian.PutUint64(b[16:], v^key64)
 			v = binary.LittleEndian.Uint64(b[24:])
-			binary.LittleEndian.PutUint64(b[24:], v^k)
+			binary.LittleEndian.PutUint64(b[24:], v^key64)
 			b = b[32:]
 		}
 
 		// Then we xor until b is less than 16 bytes.
 		for len(b) >= 16 {
 			v := binary.LittleEndian.Uint64(b)
-			binary.LittleEndian.PutUint64(b, v^k)
+			binary.LittleEndian.PutUint64(b, v^key64)
 			v = binary.LittleEndian.Uint64(b[8:])
-			binary.LittleEndian.PutUint64(b[8:], v^k)
+			binary.LittleEndian.PutUint64(b[8:], v^key64)
 			b = b[16:]
 		}
 
 		// Then we xor until b is less than 8 bytes.
 		for len(b) >= 8 {
 			v := binary.LittleEndian.Uint64(b)
-			binary.LittleEndian.PutUint64(b, v^k)
+			binary.LittleEndian.PutUint64(b, v^key64)
 			b = b[8:]
 		}
 	}
 
+	// Then we xor until b is less than 4 bytes.
+	for len(b) >= 4 {
+		v := binary.LittleEndian.Uint32(b)
+		binary.LittleEndian.PutUint32(b, v^key)
+		b = b[4:]
+	}
+
 	// xor remaining bytes.
 	for i := range b {
-		b[i] ^= key[keyPos&3]
-		keyPos++
+		b[i] ^= byte(key)
+		key = bits.RotateLeft32(key, -8)
 	}
-	return keyPos & 3
+
+	return key
 }
diff --git a/frame_test.go b/frame_test.go
index 84742ff01ee07d0f50698fc9dc4002c6babac740..c8f4cd8d59f95c036aa46b97192133e57674f21d 100644
--- a/frame_test.go
+++ b/frame_test.go
@@ -4,8 +4,10 @@ package websocket
 
 import (
 	"bytes"
+	"encoding/binary"
 	"io"
 	"math"
+	"math/bits"
 	"math/rand"
 	"strconv"
 	"strings"
@@ -133,7 +135,7 @@ func TestHeader(t *testing.T) {
 			}
 
 			if h.masked {
-				rand.Read(h.maskKey[:])
+				h.maskKey = rand.Uint32()
 			}
 
 			testHeader(t, h)
@@ -309,17 +311,17 @@ func Test_validWireCloseCode(t *testing.T) {
 func Test_xor(t *testing.T) {
 	t.Parallel()
 
-	key := [4]byte{0xa, 0xb, 0xc, 0xff}
+	key := []byte{0xa, 0xb, 0xc, 0xff}
+	key32 := binary.LittleEndian.Uint32(key)
 	p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
-	pos := 0
-	pos = fastXOR(key, pos, p)
+	gotKey32 := fastXOR(key32, p)
 
 	if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) {
 		t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p))
 	}
 
-	if exp := 1; !cmp.Equal(exp, pos) {
-		t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos))
+	if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) {
+		t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32))
 	}
 }
 
@@ -347,26 +349,37 @@ func BenchmarkXOR(b *testing.B) {
 
 	fns := []struct {
 		name string
-		fn   func([4]byte, int, []byte) int
+		fn   func(b *testing.B, key [4]byte, p []byte)
 	}{
 		{
-			"basic",
-			basixXOR,
+			name: "basic",
+			fn: func(b *testing.B, key [4]byte, p []byte) {
+				for i := 0; i < b.N; i++ {
+					basixXOR(key, 0, p)
+				}
+			},
 		},
 		{
-			"fast",
-			fastXOR,
+			name: "fast",
+			fn: func(b *testing.B, key [4]byte, p []byte) {
+				key32 := binary.BigEndian.Uint32(key[:])
+				b.ResetTimer()
+
+				for i := 0; i < b.N; i++ {
+					fastXOR(key32, p)
+				}
+			},
 		},
 	}
 
-	var maskKey [4]byte
-	_, err := rand.Read(maskKey[:])
+	var key [4]byte
+	_, err := rand.Read(key[:])
 	if err != nil {
 		b.Fatalf("failed to populate mask key: %v", err)
 	}
 
 	for _, size := range sizes {
-		data := make([]byte, size)
+		p := make([]byte, size)
 
 		b.Run(strconv.Itoa(size), func(b *testing.B) {
 			for _, fn := range fns {
@@ -374,9 +387,7 @@ func BenchmarkXOR(b *testing.B) {
 					b.ReportAllocs()
 					b.SetBytes(int64(size))
 
-					for i := 0; i < b.N; i++ {
-						fn.fn(maskKey, 0, data)
-					}
+					fn.fn(b, key, p)
 				})
 			}
 		})