diff --git a/go.mod b/go.mod
index 60d1a3d0f432ec36f998138b76bd45979c55848d..3bce5bd6e874e22fb55f78152c240f1243961f91 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module nhooyr.io/websocket
 go 1.12
 
 require (
+	github.com/google/go-cmp v0.2.0
 	github.com/kr/pretty v0.1.0 // indirect
 	go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16
 	golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3
diff --git a/go.sum b/go.sum
index efb01309f0333453695cddc3df942598a95c1fdf..9b7bf490dd3a7b9c3a610b3a3a826842b21c3dfa 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=
+github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
 github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
diff --git a/mask.go b/mask.go
index 6569bb283004b9037e9243eccea88c200a4cc8dd..6c67f1c0a4c384f14c0d7746be82c8b8a35523fe 100644
--- a/mask.go
+++ b/mask.go
@@ -6,7 +6,8 @@ package websocket
 // See https://tools.ietf.org/html/rfc6455#section-5.3
 //
 // The returned value is the position of the next byte
-// to be used for masking in the key.
+// to be used for masking in the key. This is so that
+// unmasking can be performed without the entire frame.
 func mask(key [4]byte, pos int, p []byte) int {
 	for i := range p {
 		p[i] ^= key[pos&3]
diff --git a/mask_test.go b/mask_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..4a7b8c73c3783bfc074630dd27a47ce64be72c17
--- /dev/null
+++ b/mask_test.go
@@ -0,0 +1,24 @@
+package websocket
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func Test_mask(t *testing.T) {
+	t.Parallel()
+
+	key := [4]byte{0xa, 0xb, 0xc, 0xff}
+	p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
+	pos := 0
+	pos = mask(key, pos, p)
+
+	if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) {
+		t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p))
+	}
+
+	if exp := 1; !cmp.Equal(exp, pos) {
+		t.Fatalf("unexpected mask pos: %v", cmp.Diff(exp, pos))
+	}
+}