diff --git a/.gitignore b/.gitignore
index 70d8e7030c7c59458ca2741bea53c61e7ff22715..35ecb6b04d5cfb5e5df8b468368ccb5ac941e294 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 coverage.html
 wstest_reports
+websocket.test
diff --git a/accept.go b/accept.go
index 3120690a54b88cdd2519e1d599f519a52fb5ee3e..e0c31ef5cb29e7805ec236b7629beedbc95439be 100644
--- a/accept.go
+++ b/accept.go
@@ -53,19 +53,19 @@ func AcceptInsecureOrigin() AcceptOption {
 
 func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
 	if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") {
-		err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
+		err := xerrors.Errorf("websocket: protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		return err
 	}
 
 	if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") {
-		err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
+		err := xerrors.Errorf("websocket: protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		return err
 	}
 
 	if r.Method != "GET" {
-		err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
+		err := xerrors.Errorf("websocket: protocol violation: handshake request method %q is not GET", r.Method)
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		return err
 	}
@@ -88,7 +88,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
 // Accept accepts a WebSocket handshake from a client and upgrades the
 // the connection to WebSocket.
 // Accept will reject the handshake if the Origin is not the same as the Host unless
-// InsecureAcceptOrigin is passed.
+// the AcceptInsecureOrigin option is passed.
 // Accept uses w to write the handshake response so the timeouts on the http.Server apply.
 func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
 	var subprotocols []string
diff --git a/bench_test.go b/bench_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..66331e0c72e186950ec727e333e3dcb99aca8da5
--- /dev/null
+++ b/bench_test.go
@@ -0,0 +1,87 @@
+package websocket_test
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+
+	"nhooyr.io/websocket"
+)
+
+func BenchmarkConn(b *testing.B) {
+	b.StopTimer()
+
+	s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		c, err := websocket.Accept(w, r,
+			websocket.AcceptSubprotocols("echo"),
+		)
+		if err != nil {
+			b.Logf("server handshake failed: %+v", err)
+			return
+		}
+		echoLoop(r.Context(), c)
+	}))
+	defer closeFn()
+
+	wsURL := strings.Replace(s.URL, "http", "ws", 1)
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
+	defer cancel()
+
+	c, _, err := websocket.Dial(ctx, wsURL)
+	if err != nil {
+		b.Fatalf("failed to dial: %v", err)
+	}
+	defer c.Close(websocket.StatusInternalError, "")
+
+	runN := func(n int) {
+		b.Run(strconv.Itoa(n), func(b *testing.B) {
+			msg := []byte(strings.Repeat("2", n))
+			buf := make([]byte, len(msg))
+			b.SetBytes(int64(len(msg)))
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				w, err := c.Write(ctx, websocket.MessageText)
+				if err != nil {
+					b.Fatal(err)
+				}
+
+				_, err = w.Write(msg)
+				if err != nil {
+					b.Fatal(err)
+				}
+
+				err = w.Close()
+				if err != nil {
+					b.Fatal(err)
+				}
+
+				_, r, err := c.Read(ctx)
+				if err != nil {
+					b.Fatal(err, b.N)
+				}
+
+				_, err = io.ReadFull(r, buf)
+				if err != nil {
+					b.Fatal(err)
+				}
+			}
+			b.StopTimer()
+		})
+	}
+
+	runN(32)
+	runN(128)
+	runN(512)
+	runN(1024)
+	runN(4096)
+	runN(16384)
+	runN(65536)
+	runN(131072)
+
+	c.Close(websocket.StatusNormalClosure, "")
+}
diff --git a/dial_test.go b/dial_test.go
index 48c1c3125a4b33b5eaf4eb7db82e447d29e08c54..02aaa4fc874df6ed826027cfa9e26a52b82d9f2a 100644
--- a/dial_test.go
+++ b/dial_test.go
@@ -7,6 +7,8 @@ import (
 )
 
 func Test_verifyServerHandshake(t *testing.T) {
+	t.Parallel()
+
 	testCases := []struct {
 		name     string
 		response func(w http.ResponseWriter)
diff --git a/example_test.go b/example_test.go
index 702239b2afa5530610dbe85651f737fc50dc83e9..c343d78f3b86c956937e9506600a119668665113 100644
--- a/example_test.go
+++ b/example_test.go
@@ -34,10 +34,13 @@ func ExampleAccept_echo() {
 			if err != nil {
 				return err
 			}
-
 			r = io.LimitReader(r, 32768)
 
-			w := c.Write(ctx, typ)
+			w, err := c.Write(ctx, typ)
+			if err != nil {
+				return err
+			}
+
 			_, err = io.Copy(w, r)
 			if err != nil {
 				return err
@@ -76,7 +79,7 @@ func ExampleAccept() {
 			log.Printf("server handshake failed: %v", err)
 			return
 		}
-		defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error.
+		defer c.Close(websocket.StatusInternalError, "")
 
 		jc := websocket.JSONConn{
 			Conn: c,
diff --git a/header.go b/header.go
index 276fa0c30b93f6120c18d064c96b1c5e05548f1d..82ad5f561431e322fcdd402f9c3c6cfade64f49d 100644
--- a/header.go
+++ b/header.go
@@ -4,6 +4,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"io"
+	"math"
 
 	"golang.org/x/xerrors"
 )
@@ -55,7 +56,7 @@ func marshalHeader(h header) []byte {
 		panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength))
 	case h.payloadLength <= 125:
 		b[1] = byte(h.payloadLength)
-	case h.payloadLength <= 1<<16:
+	case h.payloadLength <= math.MaxUint16:
 		b[1] = 126
 		b = b[:len(b)+2]
 		binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength))
@@ -105,10 +106,8 @@ func readHeader(r io.Reader) (header, error) {
 	case payloadLength < 126:
 		h.payloadLength = int64(payloadLength)
 	case payloadLength == 126:
-		h.payloadLength = 126
 		extra += 2
 	case payloadLength == 127:
-		h.payloadLength = 127
 		extra += 8
 	}
 
diff --git a/header_test.go b/header_test.go
index b4d0769fb2c33044fd841c04041f58a8d8328028..b9cf351b93255e09c1036be4357af0d0c6f94ef1 100644
--- a/header_test.go
+++ b/header_test.go
@@ -3,6 +3,7 @@ package websocket
 import (
 	"bytes"
 	"math/rand"
+	"strconv"
 	"testing"
 	"time"
 
@@ -36,10 +37,38 @@ func TestHeader(t *testing.T) {
 			t.Fatalf("unexpected error value: %+v", err)
 		}
 	})
+
+	t.Run("lengths", func(t *testing.T) {
+		t.Parallel()
+
+		lengths := []int{
+			124,
+			125,
+			126,
+			4096,
+			16384,
+			65535,
+			65536,
+			65537,
+			131072,
+		}
+
+		for _, n := range lengths {
+			n := n
+			t.Run(strconv.Itoa(n), func(t *testing.T) {
+				t.Parallel()
+
+				testHeader(t, header{
+					payloadLength: int64(n),
+				})
+			})
+		}
+	})
+
 	t.Run("fuzz", func(t *testing.T) {
 		t.Parallel()
 
-		for i := 0; i < 1000; i++ {
+		for i := 0; i < 10000; i++ {
 			h := header{
 				fin:    randBool(),
 				rsv1:   randBool(),
@@ -55,20 +84,24 @@ func TestHeader(t *testing.T) {
 				rand.Read(h.maskKey[:])
 			}
 
-			b := marshalHeader(h)
-			r := bytes.NewReader(b)
-			h2, err := readHeader(r)
-			if err != nil {
-				t.Logf("header: %#v", h)
-				t.Logf("bytes: %b", b)
-				t.Fatalf("failed to read header: %v", err)
-			}
-
-			if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) {
-				t.Logf("header: %#v", h)
-				t.Logf("bytes: %b", b)
-				t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
-			}
+			testHeader(t, h)
 		}
 	})
 }
+
+func testHeader(t *testing.T, h header) {
+	b := marshalHeader(h)
+	r := bytes.NewReader(b)
+	h2, err := readHeader(r)
+	if err != nil {
+		t.Logf("header: %#v", h)
+		t.Logf("bytes: %b", b)
+		t.Fatalf("failed to read header: %v", err)
+	}
+
+	if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) {
+		t.Logf("header: %#v", h)
+		t.Logf("bytes: %b", b)
+		t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
+	}
+}
diff --git a/json.go b/json.go
index 24e6f3184c4aaeb795dc7305e9856249ae373ad2..0d85a5dbee9cd176be09a505dcf483c602f29fd9 100644
--- a/json.go
+++ b/json.go
@@ -22,7 +22,7 @@ func (jc JSONConn) Read(ctx context.Context, v interface{}) error {
 	return nil
 }
 
-func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
+func (jc JSONConn) read(ctx context.Context, v interface{}) error {
 	typ, r, err := jc.Conn.Read(ctx)
 	if err != nil {
 		return err
@@ -53,10 +53,13 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
 }
 
 func (jc JSONConn) write(ctx context.Context, v interface{}) error {
-	w := jc.Conn.Write(ctx, MessageText)
+	w, err := jc.Conn.Write(ctx, MessageText)
+	if err != nil {
+		return xerrors.Errorf("failed to get message writer: %w", err)
+	}
 
 	e := json.NewEncoder(w)
-	err := e.Encode(v)
+	err = e.Encode(v)
 	if err != nil {
 		return xerrors.Errorf("failed to encode json: %w", err)
 	}
diff --git a/statuscode.go b/statuscode.go
index 2f4f2c0c735c0550621c7317c9f7457fefe882a3..d742195ba82a0ca033e61ee68783bab1f3de25aa 100644
--- a/statuscode.go
+++ b/statuscode.go
@@ -5,7 +5,6 @@ import (
 	"errors"
 	"fmt"
 	"math/bits"
-	"unicode/utf8"
 
 	"golang.org/x/xerrors"
 )
@@ -54,6 +53,12 @@ func (ce CloseError) Error() string {
 }
 
 func parseClosePayload(p []byte) (CloseError, error) {
+	if len(p) == 0 {
+		return CloseError{
+			Code: StatusNoStatusRcvd,
+		}, nil
+	}
+
 	if len(p) < 2 {
 		return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
 	}
@@ -63,9 +68,6 @@ func parseClosePayload(p []byte) (CloseError, error) {
 		Reason: string(p[2:]),
 	}
 
-	if !utf8.ValidString(ce.Reason) {
-		return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason)
-	}
 	if !validWireCloseCode(ce.Code) {
 		return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code)
 	}
diff --git a/websocket.go b/websocket.go
index 52b5d8dba2ce40146c91bfaca524a5e181d58160..79923518038ae6b7fd0052d29e203cbaac4bcc7e 100644
--- a/websocket.go
+++ b/websocket.go
@@ -7,6 +7,7 @@ import (
 	"io"
 	"runtime"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/xerrors"
@@ -34,25 +35,31 @@ type Conn struct {
 	// Writers should send on write to begin sending
 	// a message and then follow that up with some data
 	// on writeBytes.
+	// Send on control to write a control message.
+	// writeDone will be sent back when the message is written
+	// Send on writeFlush to flush the message and wait for a
+	// ping on writeDone.
+	// writeDone will be closed if the data message write errors.
 	write      chan MessageType
 	control    chan control
 	writeBytes chan []byte
 	writeDone  chan struct{}
+	writeFlush chan struct{}
 
 	// Readers should receive on read to begin reading a message.
 	// Then send a byte slice to readBytes to read into it.
 	// The n of bytes read will be sent on readDone once the read into a slice is complete.
-	// readDone will receive 0 when EOF is reached.
-	read       chan opcode
-	readBytes  chan []byte
-	readDone   chan int
-	readerDone chan struct{}
+	// readDone will be closed if the read fails.
+	// activeReader will be set to 0 on io.EOF.
+	activeReader int64
+	inMsg        bool
+	read         chan opcode
+	readBytes    chan []byte
+	readDone     chan int
 }
 
 func (c *Conn) close(err error) {
-	if err != nil {
-		err = xerrors.Errorf("websocket: connection broken: %w", err)
-	}
+	err = xerrors.Errorf("websocket: connection broken: %w", err)
 
 	c.closeOnce.Do(func() {
 		runtime.SetFinalizer(c, nil)
@@ -76,13 +83,16 @@ func (c *Conn) Subprotocol() string {
 
 func (c *Conn) init() {
 	c.closed = make(chan struct{})
+
 	c.write = make(chan MessageType)
 	c.control = make(chan control)
+	c.writeBytes = make(chan []byte)
 	c.writeDone = make(chan struct{})
+	c.writeFlush = make(chan struct{})
+
 	c.read = make(chan opcode)
-	c.readDone = make(chan int)
 	c.readBytes = make(chan []byte)
-	c.readerDone = make(chan struct{})
+	c.readDone = make(chan int)
 
 	runtime.SetFinalizer(c, func(c *Conn) {
 		c.Close(StatusInternalError, "websocket: connection ended up being garbage collected")
@@ -116,10 +126,10 @@ func (c *Conn) writeFrame(h header, p []byte) {
 }
 
 func (c *Conn) writeLoop() {
+	defer close(c.writeDone)
+
 messageLoop:
 	for {
-		c.writeBytes = make(chan []byte)
-
 		var dataType MessageType
 		select {
 		case <-c.closed:
@@ -160,9 +170,9 @@ messageLoop:
 				case c.writeDone <- struct{}{}:
 					continue
 				}
-			case b, ok := <-c.writeBytes:
+			case b := <-c.writeBytes:
 				h := header{
-					fin:           !ok,
+					fin:           false,
 					opcode:        opcode(dataType),
 					payloadLength: int64(len(b)),
 					masked:        c.client,
@@ -175,24 +185,39 @@ messageLoop:
 
 				c.writeFrame(h, b)
 
-				if !ok {
-					err := c.bw.Flush()
-					if err != nil {
-						c.close(xerrors.Errorf("failed to write to connection: %w", err))
-						return
-					}
+				select {
+				case <-c.closed:
+					return
+				case c.writeDone <- struct{}{}:
+					continue
 				}
+			case <-c.writeFlush:
+				h := header{
+					fin:           true,
+					opcode:        opcode(dataType),
+					payloadLength: 0,
+					masked:        c.client,
+				}
+
+				if firstSent {
+					h.opcode = opContinuation
+				}
+
+				c.writeFrame(h, nil)
 
 				select {
 				case <-c.closed:
 					return
 				case c.writeDone <- struct{}{}:
-					if ok {
-						continue
-					} else {
-						continue messageLoop
-					}
 				}
+
+				err := c.bw.Flush()
+				if err != nil {
+					c.close(xerrors.Errorf("failed to write to connection: %w", err))
+					return
+				}
+
+				continue messageLoop
 			}
 		}
 	}
@@ -225,17 +250,15 @@ func (c *Conn) handleControl(h header) {
 		c.writePong(b)
 	case opPong:
 	case opClose:
-		if len(b) > 0 {
-			ce, err := parseClosePayload(b)
-			if err != nil {
-				c.close(xerrors.Errorf("read invalid close payload: %w", err))
-				return
-			}
-			c.Close(ce.Code, ce.Reason)
+		ce, err := parseClosePayload(b)
+		if err != nil {
+			c.close(xerrors.Errorf("read invalid close payload: %w", err))
+			return
+		}
+		if ce.Code == StatusNoStatusRcvd {
+			c.writeClose(nil, ce)
 		} else {
-			c.writeClose(nil, CloseError{
-				Code: StatusNoStatusRcvd,
-			})
+			c.Close(ce.Code, ce.Reason)
 		}
 	default:
 		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
@@ -243,7 +266,8 @@ func (c *Conn) handleControl(h header) {
 }
 
 func (c *Conn) readLoop() {
-	var indata bool
+	defer close(c.readDone)
+
 	for {
 		h, err := readHeader(c.br)
 		if err != nil {
@@ -263,19 +287,19 @@ func (c *Conn) readLoop() {
 
 		switch h.opcode {
 		case opBinary, opText:
-			if !indata {
-				select {
-				case <-c.closed:
-					return
-				case c.read <- h.opcode:
-				}
-				indata = true
-			} else {
-				c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished")
+			if c.inMsg {
+				c.Close(StatusProtocolError, "cannot read data frame when previous frame is not finished")
+				return
+			}
+
+			select {
+			case <-c.closed:
 				return
+			case c.read <- h.opcode:
+				c.inMsg = true
 			}
 		case opContinuation:
-			if !indata {
+			if !c.inMsg {
 				c.Close(StatusProtocolError, "continuation frame not after data or text frame")
 				return
 			}
@@ -284,47 +308,55 @@ func (c *Conn) readLoop() {
 			return
 		}
 
-		maskPos := 0
-		left := h.payloadLength
-		firstRead := false
-		for left > 0 || !firstRead {
-			select {
-			case <-c.closed:
-				return
-			case b := <-c.readBytes:
-				if int64(len(b)) > left {
-					b = b[:left]
-				}
+		err = c.dataReadLoop(h)
+		if err != nil {
+			c.close(xerrors.Errorf("failed to read from connection: %w", err))
+			return
+		}
+	}
+}
 
-				_, err = io.ReadFull(c.br, b)
-				if err != nil {
-					c.close(xerrors.Errorf("failed to read from connection: %w", err))
-					return
-				}
-				left -= int64(len(b))
+func (c *Conn) dataReadLoop(h header) (err error) {
+	maskPos := 0
+	left := h.payloadLength
+	firstReadDone := false
+	for left > 0 || !firstReadDone {
+		select {
+		case <-c.closed:
+			return c.closeErr
+		case b := <-c.readBytes:
+			if int64(len(b)) > left {
+				b = b[:left]
+			}
 
-				if h.masked {
-					maskPos = mask(h.maskKey, maskPos, b)
-				}
+			_, err := io.ReadFull(c.br, b)
+			if err != nil {
+				return xerrors.Errorf("failed to read from connection: %w", err)
+			}
+			left -= int64(len(b))
 
-				select {
-				case <-c.closed:
-					return
-				case c.readDone <- len(b):
-					firstRead = true
-				}
+			if h.masked {
+				maskPos = mask(h.maskKey, maskPos, b)
+			}
+
+			// Must set this before we signal the read is done.
+			// The reader will use this to return io.EOF and
+			// c.Read will use it to check if the reader has been completed.
+			if left == 0 && h.fin {
+				atomic.StoreInt64(&c.activeReader, 0)
+				c.inMsg = false
 			}
-		}
 
-		if h.fin {
-			indata = false
 			select {
 			case <-c.closed:
-				return
-			case c.readerDone <- struct{}{}:
+				return c.closeErr
+			case c.readDone <- len(b):
+				firstReadDone = true
 			}
 		}
 	}
+
+	return nil
 }
 
 func (c *Conn) writePong(p []byte) error {
@@ -404,77 +436,65 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
 }
 
 // Write returns a writer bounded by the context that will write
-// a WebSocket data frame of type dataType to the connection.
-// Ensure you close the messageWriter once you have written to entire message.
-// Concurrent calls to messageWriter are ok.
-func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser {
-	// TODO acquire write here, move state into Conn and make messageWriter allocation free.
-	return &messageWriter{
-		c:        c,
-		ctx:      ctx,
-		datatype: dataType,
+// a WebSocket message of type dataType to the connection.
+// Ensure you close the writer once you have written the entire message.
+// Concurrent calls to Write are ok.
+func (c *Conn) Write(ctx context.Context, dataType MessageType) (io.WriteCloser, error) {
+	select {
+	case <-c.closed:
+		return nil, c.closeErr
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case c.write <- dataType:
+		return messageWriter{
+			ctx: ctx,
+			c:   c,
+		}, nil
 	}
 }
 
 // messageWriter enables writing to a WebSocket connection.
-// Ensure you close the messageWriter once you have written to entire message.
 type messageWriter struct {
-	datatype     MessageType
-	ctx          context.Context
-	c            *Conn
-	acquiredLock bool
+	ctx context.Context
+	c   *Conn
 }
 
 // Write writes the given bytes to the WebSocket connection.
 // The frame will automatically be fragmented as appropriate
 // with the buffers obtained from http.Hijacker.
 // Please ensure you call Close once you have written the full message.
-func (w *messageWriter) Write(p []byte) (int, error) {
-	err := w.acquire()
-	if err != nil {
-		return 0, err
-	}
-
+func (w messageWriter) Write(p []byte) (int, error) {
 	select {
 	case <-w.c.closed:
 		return 0, w.c.closeErr
 	case w.c.writeBytes <- p:
 		select {
-		case <-w.c.closed:
-			return 0, w.c.closeErr
-		case <-w.c.writeDone:
-			return len(p), nil
 		case <-w.ctx.Done():
+			w.c.close(xerrors.Errorf("write timed out: %w", w.ctx.Err()))
+			<-w.c.readDone
 			return 0, w.ctx.Err()
+		case _, ok := <-w.c.writeDone:
+			if !ok {
+				return 0, w.c.closeErr
+			}
+			return len(p), nil
 		}
 	case <-w.ctx.Done():
 		return 0, w.ctx.Err()
 	}
 }
 
-func (w *messageWriter) acquire() error {
-	if !w.acquiredLock {
-		select {
-		case <-w.c.closed:
-			return w.c.closeErr
-		case w.c.write <- w.datatype:
-			w.acquiredLock = true
-		case <-w.ctx.Done():
-			return w.ctx.Err()
-		}
-	}
-	return nil
-}
-
 // Close flushes the frame to the connection.
 // This must be called for every messageWriter.
-func (w *messageWriter) Close() error {
-	err := w.acquire()
-	if err != nil {
-		return err
+func (w messageWriter) Close() error {
+	select {
+	case <-w.c.closed:
+		return w.c.closeErr
+	case <-w.ctx.Done():
+		return w.ctx.Err()
+	case w.c.writeFlush <- struct{}{}:
 	}
 
-	close(w.c.writeBytes)
 	select {
 	case <-w.c.closed:
 		return w.c.closeErr
@@ -485,26 +505,45 @@ func (w *messageWriter) Close() error {
 	}
 }
 
-// ReadMessage will wait until there is a WebSocket data frame to read from the connection.
-// It returns the type of the data, a reader to read it and also an error.
-// Please use SetContext on the reader to bound the read operation.
+// ReadMessage will wait until there is a WebSocket data message to read from the connection.
+// It returns the type of the message and a reader to read it.
+// The passed context will also bound the reader.
 // Your application must keep reading messages for the Conn to automatically respond to ping
-// and close frames.
+// and close frames and not become stuck waiting for a data message to be read.
+// Please ensure to read the full message from io.Reader.
+// You can only read a single message at a time.
 func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) {
-	// TODO error if the reader is not done
+	for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
+		select {
+		case <-c.closed:
+			return 0, nil, c.closeErr
+		case c.readBytes <- nil:
+			select {
+			case <-ctx.Done():
+				return 0, nil, ctx.Err()
+			case _, ok := <-c.readDone:
+				if !ok {
+					return 0, nil, c.closeErr
+				}
+				if atomic.LoadInt64(&c.activeReader) == 1 {
+					return 0, nil, xerrors.New("websocket: previous message not fully read")
+				}
+			}
+		case <-ctx.Done():
+			return 0, nil, ctx.Err()
+		}
+	}
+
 	select {
-	case <-c.readerDone:
-		// The previous reader just hit a io.EOF, we handle it for users
-		return c.Read(ctx)
 	case <-c.closed:
-		return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr)
+		return 0, nil, xerrors.Errorf("websocket: failed to read message: %w", c.closeErr)
 	case opcode := <-c.read:
 		return MessageType(opcode), messageReader{
 			ctx: ctx,
 			c:   c,
 		}, nil
 	case <-ctx.Done():
-		return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err())
+		return 0, nil, xerrors.Errorf("websocket: failed to read message: %w", ctx.Err())
 	}
 }
 
@@ -518,30 +557,38 @@ type messageReader struct {
 func (r messageReader) Read(p []byte) (int, error) {
 	n, err := r.read(p)
 	if err != nil {
-		// Have to return io.EOF directly for now.
+		// Have to return io.EOF directly for now, cannot wrap.
 		if err == io.EOF {
-			return 0, io.EOF
+			return n, io.EOF
 		}
-		return n, xerrors.Errorf("failed to read: %w", err)
+		return n, xerrors.Errorf("websocket: failed to read: %w", err)
 	}
 	return n, nil
 }
 
-func (r messageReader) read(p []byte) (int, error) {
+func (r messageReader) read(p []byte) (_ int, err error) {
+	if atomic.LoadInt64(&r.c.activeReader) == 0 {
+		return 0, io.EOF
+	}
+
 	select {
 	case <-r.c.closed:
 		return 0, r.c.closeErr
-	case <-r.c.readerDone:
-		return 0, io.EOF
 	case r.c.readBytes <- p:
-		// TODO this is potentially racey as if we return if the context is cancelled, or the conn is closed we don't know if the p is ok to use. we must close the connection and also ensure the readLoop is done before returning, likewise with writes.
 		select {
-		case <-r.c.closed:
-			return 0, r.c.closeErr
-		case n := <-r.c.readDone:
-			return n, nil
 		case <-r.ctx.Done():
+			r.c.close(xerrors.Errorf("read timed out: %w", r.ctx.Err()))
+			// Wait for readloop to complete so we know p is done.
+			<-r.c.readDone
 			return 0, r.ctx.Err()
+		case n, ok := <-r.c.readDone:
+			if !ok {
+				return 0, r.c.closeErr
+			}
+			if atomic.LoadInt64(&r.c.activeReader) == 0 {
+				return n, io.EOF
+			}
+			return n, nil
 		}
 	case <-r.ctx.Done():
 		return 0, r.ctx.Err()
diff --git a/websocket_test.go b/websocket_test.go
index 868b69a37ccb7e0700c89627eeb1403a02d6da4f..d6d222d55e9aac2e88736975b36ee1ca7da7428b 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -196,14 +196,25 @@ func TestHandshake(t *testing.T) {
 				ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
 				defer cancel()
 
-				jc := websocket.JSONConn{
-					Conn: c,
-				}
+				write := func() error {
+					jc := websocket.JSONConn{
+						Conn: c,
+					}
 
-				v := map[string]interface{}{
-					"anmol": "wowow",
+					v := map[string]interface{}{
+						"anmol": "wowow",
+					}
+					err = jc.Write(ctx, v)
+					if err != nil {
+						return err
+					}
+					return nil
 				}
-				err = jc.Write(ctx, v)
+				err = write()
+				if err != nil {
+					return err
+				}
+				err = write()
 				if err != nil {
 					return err
 				}
@@ -222,17 +233,29 @@ func TestHandshake(t *testing.T) {
 					Conn: c,
 				}
 
-				var v interface{}
-				err = jc.Read(ctx, &v)
+				read := func() error {
+					var v interface{}
+					err = jc.Read(ctx, &v)
+					if err != nil {
+						return err
+					}
+
+					exp := map[string]interface{}{
+						"anmol": "wowow",
+					}
+					if !reflect.DeepEqual(exp, v) {
+						return xerrors.Errorf("expected %v but got %v", exp, v)
+					}
+					return nil
+				}
+				err = read()
 				if err != nil {
 					return err
 				}
-
-				exp := map[string]interface{}{
-					"anmol": "wowow",
-				}
-				if !reflect.DeepEqual(exp, v) {
-					return xerrors.Errorf("expected %v but got %v", exp, v)
+				// Read twice to ensure the un EOFed previous reader works correctly.
+				err = read()
+				if err != nil {
+					return err
 				}
 
 				c.Close(websocket.StatusNormalClosure, "")
@@ -292,29 +315,14 @@ func TestHandshake(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Parallel()
 
-			var conns int64
-			s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				atomic.AddInt64(&conns, 1)
-				defer atomic.AddInt64(&conns, -1)
-
+			s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
 				err := tc.server(w, r)
 				if err != nil {
 					t.Errorf("server failed: %+v", err)
 					return
 				}
-			}))
-			defer func() {
-				s.Close()
-
-				ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-				defer cancel()
-
-				for atomic.LoadInt64(&conns) > 0 {
-					if ctx.Err() != nil {
-						t.Fatalf("waiting for server to come down timed out: %v", ctx.Err())
-					}
-				}
-			}()
+			})
+			defer closeFn()
 
 			wsURL := strings.Replace(s.URL, "http", "ws", 1)
 
@@ -329,6 +337,28 @@ func TestHandshake(t *testing.T) {
 	}
 }
 
+func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) {
+	var conns int64
+	s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt64(&conns, 1)
+		defer atomic.AddInt64(&conns, -1)
+
+		fn.ServeHTTP(w, r)
+	}))
+	return s, func() {
+		s.Close()
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+		defer cancel()
+
+		for atomic.LoadInt64(&conns) > 0 {
+			if ctx.Err() != nil {
+				tb.Fatalf("waiting for server to come down timed out: %v", ctx.Err())
+			}
+		}
+	}
+}
+
 // https://github.com/crossbario/autobahn-python/tree/master/wstest
 func TestAutobahnServer(t *testing.T) {
 	t.Parallel()
@@ -341,7 +371,7 @@ func TestAutobahnServer(t *testing.T) {
 			t.Logf("server handshake failed: %+v", err)
 			return
 		}
-		echoLoop(r.Context(), c, t)
+		echoLoop(r.Context(), c)
 	}))
 	defer s.Close()
 
@@ -354,7 +384,7 @@ func TestAutobahnServer(t *testing.T) {
 			},
 		},
 		"cases":         []string{"*"},
-		"exclude-cases": []string{"6.*", "12.*", "13.*"},
+		"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
 	}
 	specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json")
 	if err != nil {
@@ -388,21 +418,25 @@ func TestAutobahnServer(t *testing.T) {
 	checkWSTestIndex(t, "./wstest_reports/server/index.json")
 }
 
-func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) {
+func echoLoop(ctx context.Context, c *websocket.Conn) {
 	defer c.Close(websocket.StatusInternalError, "")
 
-	echo := func() error {
-		ctx, cancel := context.WithTimeout(ctx, time.Second*30)
-		defer cancel()
+	ctx, cancel := context.WithTimeout(ctx, time.Minute)
+	defer cancel()
 
+	b := make([]byte, 32768)
+	echo := func() error {
 		typ, r, err := c.Read(ctx)
 		if err != nil {
 			return err
 		}
 
-		w := c.Write(ctx, typ)
+		w, err := c.Write(ctx, typ)
+		if err != nil {
+			return err
+		}
 
-		_, err = io.Copy(w, r)
+		_, err = io.CopyBuffer(w, r, b)
 		if err != nil {
 			return err
 		}
@@ -431,7 +465,7 @@ func TestAutobahnClient(t *testing.T) {
 		"url":           "ws://localhost:9001",
 		"outdir":        "wstest_reports/client",
 		"cases":         []string{"*"},
-		"exclude-cases": []string{"6.*", "12.*", "13.*"},
+		"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
 	}
 	specFile, err := ioutil.TempFile("", "websocket_fuzzingserver.json")
 	if err != nil {
@@ -507,7 +541,7 @@ func TestAutobahnClient(t *testing.T) {
 			if err != nil {
 				t.Fatalf("failed to dial: %v", err)
 			}
-			echoLoop(ctx, c, t)
+			echoLoop(ctx, c)
 		}()
 	}