diff --git a/netconn.go b/netconn.go
index a6f902da9f6d6ab44cddb73d3d8dd49d020cf97e..a7c9bf7fcf6a444deac547fa84ab89e3ee28ca82 100644
--- a/netconn.go
+++ b/netconn.go
@@ -21,8 +21,11 @@ import (
 // Every Write to the net.Conn will correspond to a message write of
 // the given type on *websocket.Conn.
 //
-// If a message is read that is not of the correct type, an error
-// will be thrown.
+// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
+// all reads and writes on the net.Conn will be cancelled.
+//
+// If a message is read that is not of the correct type, the connection
+// will be closed with StatusUnsupportedData and an error will be returned.
 //
 // Close will close the *websocket.Conn with StatusNormalClosure.
 //
@@ -35,20 +38,20 @@ import (
 //
 // A received StatusNormalClosure or StatusGoingAway close frame will be translated to
 // io.EOF when reading.
-func NetConn(c *Conn, msgType MessageType) net.Conn {
+func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
 	nc := &netConn{
 		c:       c,
 		msgType: msgType,
 	}
 
 	var cancel context.CancelFunc
-	nc.writeContext, cancel = context.WithCancel(context.Background())
+	nc.writeContext, cancel = context.WithCancel(ctx)
 	nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
 	if !nc.writeTimer.Stop() {
 		<-nc.writeTimer.C
 	}
 
-	nc.readContext, cancel = context.WithCancel(context.Background())
+	nc.readContext, cancel = context.WithCancel(ctx)
 	nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
 	if !nc.readTimer.Stop() {
 		<-nc.readTimer.C
diff --git a/websocket_test.go b/websocket_test.go
index 27750bca1f44e36169a80f98c764ae352c60723a..979b092cf8b2c4dcbc1aabfb21952a447aba5ae9 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -264,7 +264,7 @@ func TestConn(t *testing.T) {
 		{
 			name: "netConn",
 			server: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(c, websocket.MessageBinary)
+				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
 				defer nc.Close()
 
 				nc.SetWriteDeadline(time.Time{})
@@ -290,7 +290,7 @@ func TestConn(t *testing.T) {
 				return nil
 			},
 			client: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(c, websocket.MessageBinary)
+				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
 
 				nc.SetReadDeadline(time.Time{})
 				time.Sleep(1)
@@ -317,7 +317,7 @@ func TestConn(t *testing.T) {
 		{
 			name: "netConn/badReadMsgType",
 			server: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(c, websocket.MessageBinary)
+				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
 
 				nc.SetDeadline(time.Now().Add(time.Second * 15))
 
@@ -337,7 +337,7 @@ func TestConn(t *testing.T) {
 		{
 			name: "netConn/badRead",
 			server: func(ctx context.Context, c *websocket.Conn) error {
-				nc := websocket.NetConn(c, websocket.MessageBinary)
+				nc := websocket.NetConn(ctx, c, websocket.MessageBinary)
 				defer nc.Close()
 
 				nc.SetDeadline(time.Now().Add(time.Second * 15))