diff --git a/netconn.go b/netconn.go index a6f902da9f6d6ab44cddb73d3d8dd49d020cf97e..a7c9bf7fcf6a444deac547fa84ab89e3ee28ca82 100644 --- a/netconn.go +++ b/netconn.go @@ -21,8 +21,11 @@ import ( // Every Write to the net.Conn will correspond to a message write of // the given type on *websocket.Conn. // -// If a message is read that is not of the correct type, an error -// will be thrown. +// The passed ctx bounds the lifetime of the net.Conn. If cancelled, +// all reads and writes on the net.Conn will be cancelled. +// +// If a message is read that is not of the correct type, the connection +// will be closed with StatusUnsupportedData and an error will be returned. // // Close will close the *websocket.Conn with StatusNormalClosure. // @@ -35,20 +38,20 @@ import ( // // A received StatusNormalClosure or StatusGoingAway close frame will be translated to // io.EOF when reading. -func NetConn(c *Conn, msgType MessageType) net.Conn { +func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { nc := &netConn{ c: c, msgType: msgType, } var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(context.Background()) + nc.writeContext, cancel = context.WithCancel(ctx) nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(context.Background()) + nc.readContext, cancel = context.WithCancel(ctx) nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) if !nc.readTimer.Stop() { <-nc.readTimer.C diff --git a/websocket_test.go b/websocket_test.go index 27750bca1f44e36169a80f98c764ae352c60723a..979b092cf8b2c4dcbc1aabfb21952a447aba5ae9 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -264,7 +264,7 @@ func TestConn(t *testing.T) { { name: "netConn", server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(c, websocket.MessageBinary) + nc := websocket.NetConn(ctx, c, websocket.MessageBinary) defer nc.Close() nc.SetWriteDeadline(time.Time{}) @@ -290,7 +290,7 @@ func TestConn(t *testing.T) { return nil }, client: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(c, websocket.MessageBinary) + nc := websocket.NetConn(ctx, c, websocket.MessageBinary) nc.SetReadDeadline(time.Time{}) time.Sleep(1) @@ -317,7 +317,7 @@ func TestConn(t *testing.T) { { name: "netConn/badReadMsgType", server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(c, websocket.MessageBinary) + nc := websocket.NetConn(ctx, c, websocket.MessageBinary) nc.SetDeadline(time.Now().Add(time.Second * 15)) @@ -337,7 +337,7 @@ func TestConn(t *testing.T) { { name: "netConn/badRead", server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(c, websocket.MessageBinary) + nc := websocket.NetConn(ctx, c, websocket.MessageBinary) defer nc.Close() nc.SetDeadline(time.Now().Add(time.Second * 15))