From 711cce472d00c014c451069bb09a1b5e8c911154 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Thu, 15 Aug 2019 16:51:07 -0700
Subject: [PATCH] Add msgType parameter to NetConn adapter

Closes #113
---
 netconn.go        | 24 +++++++++++++++---------
 websocket_test.go |  4 ++--
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/netconn.go b/netconn.go
index 06cbc2f..2578770 100644
--- a/netconn.go
+++ b/netconn.go
@@ -2,6 +2,7 @@ package websocket
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"math"
 	"net"
@@ -17,8 +18,11 @@ import (
 // correctly and so provided in the library.
 // See https://github.com/nhooyr/websocket/issues/100.
 //
-// Every Write to the net.Conn will correspond to a binary message
-// write on *webscoket.Conn.
+// Every Write to the net.Conn will correspond to a message write of
+// the given type on *websocket.Conn.
+//
+// If a message is read that is not of the correct type, an error
+// will be thrown.
 //
 // Close will close the *websocket.Conn with StatusNormalClosure.
 //
@@ -30,9 +34,10 @@ import (
 // and "websocket/unknown-addr" for String.
 //
 // A received StatusNormalClosure close frame will be translated to EOF when reading.
-func NetConn(c *Conn) net.Conn {
+func NetConn(c *Conn, msgType MessageType) net.Conn {
 	nc := &netConn{
-		c: c,
+		c:       c,
+		msgType: msgType,
 	}
 
 	var cancel context.CancelFunc
@@ -52,7 +57,8 @@ func NetConn(c *Conn) net.Conn {
 }
 
 type netConn struct {
-	c *Conn
+	c       *Conn
+	msgType MessageType
 
 	writeTimer   *time.Timer
 	writeContext context.Context
@@ -71,7 +77,7 @@ func (c *netConn) Close() error {
 }
 
 func (c *netConn) Write(p []byte) (int, error) {
-	err := c.c.Write(c.writeContext, MessageBinary, p)
+	err := c.c.Write(c.writeContext, c.msgType, p)
 	if err != nil {
 		return 0, err
 	}
@@ -93,9 +99,9 @@ func (c *netConn) Read(p []byte) (int, error) {
 			}
 			return 0, err
 		}
-		if typ != MessageBinary {
-			c.c.Close(StatusUnsupportedData, "can only accept binary messages")
-			return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ)
+		if typ != c.msgType {
+			c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType))
+			return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ)
 		}
 		c.reader = r
 	}
diff --git a/websocket_test.go b/websocket_test.go
index 46f9c83..06e0fc6 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -127,7 +127,7 @@ func TestHandshake(t *testing.T) {
 				}
 				defer c.Close(websocket.StatusInternalError, "")
 
-				nc := websocket.NetConn(c)
+				nc := websocket.NetConn(c, websocket.MessageBinary)
 				defer nc.Close()
 
 				nc.SetWriteDeadline(time.Time{})
@@ -152,7 +152,7 @@ func TestHandshake(t *testing.T) {
 				}
 				defer c.Close(websocket.StatusInternalError, "")
 
-				nc := websocket.NetConn(c)
+				nc := websocket.NetConn(c, websocket.MessageBinary)
 				defer nc.Close()
 
 				nc.SetReadDeadline(time.Time{})
-- 
GitLab