From a2a2d31cb3d23134087d033f88b340bf3b25b686 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 1 Jul 2019 10:29:59 -0400
Subject: [PATCH] Add NetConn adapter

Closes #100
---
 netconn.go        | 116 ++++++++++++++++++++++++++++++++++++++++++++++
 websocket_test.go |  48 +++++++++++++++++++
 2 files changed, 164 insertions(+)
 create mode 100644 netconn.go

diff --git a/netconn.go b/netconn.go
new file mode 100644
index 0000000..0de2f1c
--- /dev/null
+++ b/netconn.go
@@ -0,0 +1,116 @@
+package websocket
+
+import (
+	"context"
+	"golang.org/x/xerrors"
+	"io"
+	"math"
+	"net"
+	"time"
+)
+
+// NetConn converts a *websocket.Conn into a net.Conn.
+// Every Write to the net.Conn will correspond to a binary message
+// write on *webscoket.Conn.
+// Close will close the *websocket.Conn with StatusNormalClosure.
+// When a deadline is hit, the connection will be closed. This is
+// different from most net.Conn implementations where only the
+// reading/writing goroutines are interrupted but the connection is kept alive.
+// The Addr methods will return zero value net.TCPAddr.
+func NetConn(c *Conn) net.Conn {
+	nc := &netConn{
+		c: c,
+	}
+
+	var cancel context.CancelFunc
+	nc.writeContext, cancel = context.WithCancel(context.Background())
+	nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
+	nc.writeTimer.Stop()
+
+	nc.readContext, cancel = context.WithCancel(context.Background())
+	nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
+	nc.readTimer.Stop()
+
+	return nc
+}
+
+type netConn struct {
+	c *Conn
+
+	writeTimer   *time.Timer
+	writeContext context.Context
+
+	readTimer   *time.Timer
+	readContext context.Context
+
+	reader io.Reader
+}
+
+var _ net.Conn = &netConn{}
+
+func (c *netConn) Close() error {
+	return c.c.Close(StatusNormalClosure, "")
+}
+
+func (c *netConn) Write(p []byte) (int, error) {
+	err := c.c.Write(c.writeContext, MessageBinary, p)
+	if err != nil {
+		return 0, err
+	}
+	return len(p), nil
+}
+
+func (c *netConn) Read(p []byte) (int, error) {
+	if c.reader == nil {
+		typ, r, err := c.c.Reader(c.readContext)
+		if err != nil {
+			return 0, err
+		}
+		if typ != MessageBinary {
+			c.c.Close(StatusUnsupportedData, "can only accept binary messages")
+			return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ)
+		}
+		c.reader = r
+	}
+
+	n, err := c.reader.Read(p)
+	if err == io.EOF {
+		c.reader = nil
+	}
+	return n, err
+}
+
+type unknownAddr struct {
+}
+
+func (a unknownAddr) Network() string {
+	return "unknown"
+}
+
+func (a unknownAddr) String() string {
+	return "unknown"
+}
+
+func (c *netConn) RemoteAddr() net.Addr {
+	return unknownAddr{}
+}
+
+func (c *netConn) LocalAddr() net.Addr {
+	return unknownAddr{}
+}
+
+func (c *netConn) SetDeadline(t time.Time) error {
+	c.SetWriteDeadline(t)
+	c.SetReadDeadline(t)
+	return nil
+}
+
+func (c *netConn) SetWriteDeadline(t time.Time) error {
+	c.writeTimer.Reset(t.Sub(time.Now()))
+	return nil
+}
+
+func (c *netConn) SetReadDeadline(t time.Time) error {
+	c.readTimer.Reset(t.Sub(time.Now()))
+	return nil
+}
diff --git a/websocket_test.go b/websocket_test.go
index 2d7db27..2112ff7 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -118,6 +118,54 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 		},
+		{
+			name: "netConn",
+			server: func(w http.ResponseWriter, r *http.Request) error {
+				c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
+				if err != nil {
+					return err
+				}
+				defer c.Close(websocket.StatusInternalError, "")
+
+				nc := websocket.NetConn(c)
+				defer nc.Close()
+
+				nc.SetWriteDeadline(time.Now().Add(time.Second * 10))
+
+				_, err = nc.Write([]byte("hello"))
+				if err != nil {
+					return err
+				}
+
+				return nil
+			},
+			client: func(ctx context.Context, u string) error {
+				c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
+					Subprotocols: []string{"meow"},
+				})
+				if err != nil {
+					return err
+				}
+				defer c.Close(websocket.StatusInternalError, "")
+
+				nc := websocket.NetConn(c)
+				defer nc.Close()
+
+				nc.SetReadDeadline(time.Now().Add(time.Second * 10))
+
+				p := make([]byte, len("hello"))
+				_, err = io.ReadFull(nc, p)
+				if err != nil {
+					return err
+				}
+
+				if string(p) != "hello" {
+					return xerrors.Errorf("unexpected payload %q received", string(p))
+				}
+
+				return nil
+			},
+		},
 		{
 			name: "defaultSubprotocol",
 			server: func(w http.ResponseWriter, r *http.Request) error {
-- 
GitLab