From 711cce472d00c014c451069bb09a1b5e8c911154 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Thu, 15 Aug 2019 16:51:07 -0700 Subject: [PATCH] Add msgType parameter to NetConn adapter Closes #113 --- netconn.go | 24 +++++++++++++++--------- websocket_test.go | 4 ++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/netconn.go b/netconn.go index 06cbc2f..2578770 100644 --- a/netconn.go +++ b/netconn.go @@ -2,6 +2,7 @@ package websocket import ( "context" + "fmt" "io" "math" "net" @@ -17,8 +18,11 @@ import ( // correctly and so provided in the library. // See https://github.com/nhooyr/websocket/issues/100. // -// Every Write to the net.Conn will correspond to a binary message -// write on *webscoket.Conn. +// Every Write to the net.Conn will correspond to a message write of +// the given type on *websocket.Conn. +// +// If a message is read that is not of the correct type, an error +// will be thrown. // // Close will close the *websocket.Conn with StatusNormalClosure. // @@ -30,9 +34,10 @@ import ( // and "websocket/unknown-addr" for String. // // A received StatusNormalClosure close frame will be translated to EOF when reading. -func NetConn(c *Conn) net.Conn { +func NetConn(c *Conn, msgType MessageType) net.Conn { nc := &netConn{ - c: c, + c: c, + msgType: msgType, } var cancel context.CancelFunc @@ -52,7 +57,8 @@ func NetConn(c *Conn) net.Conn { } type netConn struct { - c *Conn + c *Conn + msgType MessageType writeTimer *time.Timer writeContext context.Context @@ -71,7 +77,7 @@ func (c *netConn) Close() error { } func (c *netConn) Write(p []byte) (int, error) { - err := c.c.Write(c.writeContext, MessageBinary, p) + err := c.c.Write(c.writeContext, c.msgType, p) if err != nil { return 0, err } @@ -93,9 +99,9 @@ func (c *netConn) Read(p []byte) (int, error) { } return 0, err } - if typ != MessageBinary { - c.c.Close(StatusUnsupportedData, "can only accept binary messages") - return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ) + if typ != c.msgType { + c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType)) + return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ) } c.reader = r } diff --git a/websocket_test.go b/websocket_test.go index 46f9c83..06e0fc6 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -127,7 +127,7 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") - nc := websocket.NetConn(c) + nc := websocket.NetConn(c, websocket.MessageBinary) defer nc.Close() nc.SetWriteDeadline(time.Time{}) @@ -152,7 +152,7 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") - nc := websocket.NetConn(c) + nc := websocket.NetConn(c, websocket.MessageBinary) defer nc.Close() nc.SetReadDeadline(time.Time{}) -- GitLab