From 6f6fa430a6e88699b3b8aef5d1b8499100f3e8b9 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 30 Dec 2019 22:03:21 -0500
Subject: [PATCH] Refactor autobahn

---
 accept.go        |   2 -
 autobahn_test.go | 319 ++++++++++++++++++++++++++---------------------
 close.go         |   2 +
 compress.go      |  18 +--
 conn.go          |   7 +-
 go.mod           |   2 -
 go.sum           |   7 --
 read.go          |  74 ++++++-----
 write.go         |  45 ++++---
 9 files changed, 260 insertions(+), 216 deletions(-)

diff --git a/accept.go b/accept.go
index ea7beeb..f16180f 100644
--- a/accept.go
+++ b/accept.go
@@ -37,8 +37,6 @@ type AcceptOptions struct {
 	// If used incorrectly your WebSocket server will be open to CSRF attacks.
 	InsecureSkipVerify bool
 
-	// CompressionMode sets the compression mode.
-	// See the docs on CompressionMode.
 	CompressionMode CompressionMode
 }
 
diff --git a/autobahn_test.go b/autobahn_test.go
index 6b3b5b7..16384b2 100644
--- a/autobahn_test.go
+++ b/autobahn_test.go
@@ -9,7 +9,6 @@ import (
 	"io/ioutil"
 	"net"
 	"net/http"
-	"net/http/httptest"
 	"os"
 	"os/exec"
 	"strconv"
@@ -17,9 +16,27 @@ import (
 	"testing"
 	"time"
 
+	"cdr.dev/slog/sloggers/slogtest/assert"
+
 	"nhooyr.io/websocket"
+	"nhooyr.io/websocket/internal/errd"
 )
 
+var excludedAutobahnCases = []string{
+	// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
+	// more performance overhead.
+	"6.*", "7.5.1",
+
+	// We skip the tests related to requestMaxWindowBits as that is unimplemented due
+	// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
+	"13.3.*", "13.4.*", "13.5.*", "13.6.*",
+
+	"12.*",
+	"13.*",
+}
+
+var autobahnCases = []string{"*"}
+
 // https://github.com/crossbario/autobahn-python/tree/master/wstest
 func TestAutobahn(t *testing.T) {
 	t.Parallel()
@@ -35,19 +52,17 @@ func TestAutobahn(t *testing.T) {
 func testServerAutobahn(t *testing.T) {
 	t.Parallel()
 
-	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
 		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
 			Subprotocols: []string{"echo"},
 		})
-		if err != nil {
-			t.Logf("server handshake failed: %+v", err)
-			return
-		}
-		echoLoop(r.Context(), c)
-	}))
-	defer s.Close()
+		assert.Success(t, "accept", err)
+		err = echoLoop(r.Context(), c)
+		assertCloseStatus(t, websocket.StatusNormalClosure, err)
+	}, false)
+	defer closeFn()
 
-	spec := map[string]interface{}{
+	specFile, err := tempJSONFile(map[string]interface{}{
 		"outdir": "ci/out/wstestServerReports",
 		"servers": []interface{}{
 			map[string]interface{}{
@@ -55,92 +70,105 @@ func testServerAutobahn(t *testing.T) {
 				"url":   strings.Replace(s.URL, "http", "ws", 1),
 			},
 		},
-		"cases": []string{"*"},
-		// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
-		// more performance overhead. 7.5.1 is the same.
-		"exclude-cases": []string{"6.*", "7.5.1"},
-	}
-	specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json")
-	if err != nil {
-		t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err)
-	}
-	defer specFile.Close()
-
-	e := json.NewEncoder(specFile)
-	e.SetIndent("", "\t")
-	err = e.Encode(spec)
-	if err != nil {
-		t.Fatalf("failed to write spec: %v", err)
-	}
-
-	err = specFile.Close()
-	if err != nil {
-		t.Fatalf("failed to close file: %v", err)
-	}
+		"cases":         autobahnCases,
+		"exclude-cases": excludedAutobahnCases,
+	})
+	assert.Success(t, "tempJSONFile", err)
 
-	ctx := context.Background()
-	ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
 	defer cancel()
 
-	args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()}
+	args := []string{"--mode", "fuzzingclient", "--spec", specFile}
 	wstest := exec.CommandContext(ctx, "wstest", args...)
-	out, err := wstest.CombinedOutput()
-	if err != nil {
-		t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out)
-	}
+	_, err = wstest.CombinedOutput()
+	assert.Success(t, "wstest", err)
 
 	checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
 }
 
-func unusedListenAddr() (string, error) {
-	l, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		return "", err
-	}
-	l.Close()
-	return l.Addr().String(), nil
-}
-
 func testClientAutobahn(t *testing.T) {
 	t.Parallel()
 
-	serverAddr, err := unusedListenAddr()
-	if err != nil {
-		t.Fatalf("failed to get unused listen addr for wstest: %v", err)
-	}
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
+	defer cancel()
 
-	wsServerURL := "ws://" + serverAddr
+	wstestURL, closeFn, err := wstestClientServer(ctx)
+	assert.Success(t, "wstestClient", err)
+	defer closeFn()
 
-	spec := map[string]interface{}{
-		"url":    wsServerURL,
-		"outdir": "ci/out/wstestClientReports",
-		"cases":  []string{"*"},
-		// See TestAutobahnServer for the reasons why we exclude these.
-		"exclude-cases": []string{"6.*", "7.5.1"},
-	}
-	specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json")
-	if err != nil {
-		t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err)
+	err = waitWS(ctx, wstestURL)
+	assert.Success(t, "waitWS", err)
+
+	cases, err := wstestCaseCount(ctx, wstestURL)
+	assert.Success(t, "wstestCaseCount", err)
+
+	t.Run("cases", func(t *testing.T) {
+		for i := 1; i <= cases; i++ {
+			i := i
+			t.Run("", func(t *testing.T) {
+				t.Parallel()
+
+				ctx, cancel := context.WithTimeout(ctx, time.Second*45)
+				defer cancel()
+
+				c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil)
+				assert.Success(t, "autobahn dial", err)
+
+				err = echoLoop(ctx, c)
+				t.Logf("echoLoop: %+v", err)
+			})
+		}
+	})
+
+	c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil)
+	assert.Success(t, "dial", err)
+	c.Close(websocket.StatusNormalClosure, "")
+
+	checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
+}
+
+func waitWS(ctx context.Context, url string) error {
+	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
+	defer cancel()
+
+	for ctx.Err() == nil {
+		c, _, err := websocket.Dial(ctx, url, nil)
+		if err != nil {
+			continue
+		}
+		c.Close(websocket.StatusNormalClosure, "")
+		return nil
 	}
-	defer specFile.Close()
 
-	e := json.NewEncoder(specFile)
-	e.SetIndent("", "\t")
-	err = e.Encode(spec)
+	return ctx.Err()
+}
+
+func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) {
+	serverAddr, err := unusedListenAddr()
 	if err != nil {
-		t.Fatalf("failed to write spec: %v", err)
+		return "", nil, err
 	}
 
-	err = specFile.Close()
+	url = "ws://" + serverAddr
+
+	specFile, err := tempJSONFile(map[string]interface{}{
+		"url":           url,
+		"outdir":        "ci/out/wstestClientReports",
+		"cases":         autobahnCases,
+		"exclude-cases": excludedAutobahnCases,
+	})
 	if err != nil {
-		t.Fatalf("failed to close file: %v", err)
+		return "", nil, fmt.Errorf("failed to write spec: %w", err)
 	}
 
-	ctx := context.Background()
-	ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
-	defer cancel()
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
+	defer func() {
+		if err != nil {
+			cancel()
+		}
+	}()
 
-	args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(),
+	args := []string{"--mode", "fuzzingserver", "--spec", specFile,
 		// Disables some server that runs as part of fuzzingserver mode.
 		// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
 		"--webport=0",
@@ -148,101 +176,104 @@ func testClientAutobahn(t *testing.T) {
 	wstest := exec.CommandContext(ctx, "wstest", args...)
 	err = wstest.Start()
 	if err != nil {
-		t.Fatal(err)
+		return "", nil, fmt.Errorf("failed to start wstest: %w", err)
 	}
-	defer func() {
-		err := wstest.Process.Kill()
-		if err != nil {
-			t.Error(err)
-		}
-	}()
-
-	// Let it come up.
-	time.Sleep(time.Second * 5)
-
-	var cases int
-	func() {
-		c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil)
-		if err != nil {
-			t.Fatal(err)
-		}
-		defer c.Close(websocket.StatusInternalError, "")
-
-		_, r, err := c.Reader(ctx)
-		if err != nil {
-			t.Fatal(err)
-		}
-		b, err := ioutil.ReadAll(r)
-		if err != nil {
-			t.Fatal(err)
-		}
-		cases, err = strconv.Atoi(string(b))
-		if err != nil {
-			t.Fatal(err)
-		}
 
-		c.Close(websocket.StatusNormalClosure, "")
-	}()
+	return url, func() {
+		wstest.Process.Kill()
+	}, nil
+}
 
-	for i := 1; i <= cases; i++ {
-		func() {
-			ctx, cancel := context.WithTimeout(ctx, time.Second*45)
-			defer cancel()
+func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
+	defer errd.Wrap(&err, "failed to get case count")
 
-			c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil)
-			if err != nil {
-				t.Fatal(err)
-			}
-			echoLoop(ctx, c)
-		}()
+	c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil)
+	if err != nil {
+		return 0, err
 	}
+	defer c.Close(websocket.StatusInternalError, "")
 
-	c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil)
+	_, r, err := c.Reader(ctx)
+	if err != nil {
+		return 0, err
+	}
+	b, err := ioutil.ReadAll(r)
+	if err != nil {
+		return 0, err
+	}
+	cases, err = strconv.Atoi(string(b))
 	if err != nil {
-		t.Fatal(err)
+		return 0, err
 	}
+
 	c.Close(websocket.StatusNormalClosure, "")
 
-	checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
+	return cases, nil
 }
 
 func checkWSTestIndex(t *testing.T, path string) {
 	wstestOut, err := ioutil.ReadFile(path)
-	if err != nil {
-		t.Fatalf("failed to read index.json: %v", err)
-	}
+	assert.Success(t, "ioutil.ReadFile", err)
 
 	var indexJSON map[string]map[string]struct {
 		Behavior      string `json:"behavior"`
 		BehaviorClose string `json:"behaviorClose"`
 	}
 	err = json.Unmarshal(wstestOut, &indexJSON)
-	if err != nil {
-		t.Fatalf("failed to unmarshal index.json: %v", err)
-	}
+	assert.Success(t, "json.Unmarshal", err)
 
-	var failed bool
 	for _, tests := range indexJSON {
 		for test, result := range tests {
-			switch result.Behavior {
-			case "OK", "NON-STRICT", "INFORMATIONAL":
-			default:
-				failed = true
-				t.Errorf("test %v failed", test)
-			}
-			switch result.BehaviorClose {
-			case "OK", "INFORMATIONAL":
-			default:
-				failed = true
-				t.Errorf("bad close behaviour for test %v", test)
-			}
+			t.Run(test, func(t *testing.T) {
+				switch result.BehaviorClose {
+				case "OK", "INFORMATIONAL":
+				default:
+					t.Errorf("bad close behaviour")
+				}
+
+				switch result.Behavior {
+				case "OK", "NON-STRICT", "INFORMATIONAL":
+				default:
+					t.Errorf("failed")
+				}
+			})
 		}
 	}
 
-	if failed {
-		path = strings.Replace(path, ".json", ".html", 1)
-		if os.Getenv("CI") == "" {
-			t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path)
-		}
+	if t.Failed() {
+		htmlPath := strings.Replace(path, ".json", ".html", 1)
+		t.Errorf("detected autobahn violation, see %q", htmlPath)
 	}
 }
+
+func unusedListenAddr() (_ string, err error) {
+	defer errd.Wrap(&err, "failed to get unused listen address")
+	l, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		return "", err
+	}
+	l.Close()
+	return l.Addr().String(), nil
+}
+
+func tempJSONFile(v interface{}) (string, error) {
+	f, err := ioutil.TempFile("", "temp.json")
+	if err != nil {
+		return "", fmt.Errorf("temp file: %w", err)
+	}
+	defer f.Close()
+
+	e := json.NewEncoder(f)
+	e.SetIndent("", "\t")
+	err = e.Encode(v)
+	if err != nil {
+		return "", fmt.Errorf("json encode: %w", err)
+	}
+
+	err = f.Close()
+	if err != nil {
+		return "", fmt.Errorf("close temp file: %w", err)
+	}
+
+	return f.Name(), nil
+}
diff --git a/close.go b/close.go
index 7ccdb17..c5c51c6 100644
--- a/close.go
+++ b/close.go
@@ -147,6 +147,8 @@ func (c *Conn) writeClose(code StatusCode, reason string) error {
 }
 
 func (c *Conn) waitCloseHandshake() error {
+	defer c.close(nil)
+
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	defer cancel()
 
diff --git a/compress.go b/compress.go
index 2410cb4..8c4dbe2 100644
--- a/compress.go
+++ b/compress.go
@@ -9,6 +9,14 @@ import (
 	"sync"
 )
 
+type CompressionOptions struct {
+	// Mode controls the compression mode.
+	Mode CompressionMode
+
+	// Threshold controls the minimum size of a message before compression is applied.
+	Threshold int
+}
+
 // CompressionMode controls the modes available RFC 7692's deflate extension.
 // See https://tools.ietf.org/html/rfc7692
 //
@@ -29,14 +37,8 @@ const (
 	// The message will only be compressed if greater than 512 bytes.
 	CompressionNoContextTakeover CompressionMode = iota
 
-	// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
-	// This enables reusing the sliding window from previous messages.
-	// As most WebSocket protocols are repetitive, this can be very efficient.
-	//
-	// The message will only be compressed if greater than 128 bytes.
-	//
-	// If the peer negotiates NoContextTakeover on the client or server side, it will be
-	// used instead as this is required by the RFC.
+	// Unimplemented for now due to limitations in compress/flate.
+	// See https://github.com/golang/go/issues/31514#issuecomment-569668619
 	CompressionContextTakeover
 
 	// CompressionDisabled disables the deflate extension.
diff --git a/conn.go b/conn.go
index 061c451..5ccf9f9 100644
--- a/conn.go
+++ b/conn.go
@@ -176,7 +176,7 @@ func (c *Conn) timeoutLoop() {
 	}
 }
 
-func (c *Conn) deflate() bool {
+func (c *Conn) flate() bool {
 	return c.copts != nil
 }
 
@@ -262,5 +262,8 @@ func (m *mu) TryLock() bool {
 }
 
 func (m *mu) Unlock() {
-	<-m.ch
+	select {
+	case <-m.ch:
+	default:
+	}
 }
diff --git a/go.mod b/go.mod
index 0609848..01ec18f 100644
--- a/go.mod
+++ b/go.mod
@@ -9,7 +9,5 @@ require (
 	github.com/gobwas/ws v1.0.2
 	github.com/golang/protobuf v1.3.2
 	github.com/gorilla/websocket v1.4.1
-	github.com/mattn/goveralls v0.0.4 // indirect
 	golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
-	golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea // indirect
 )
diff --git a/go.sum b/go.sum
index df11eba..864efaa 100644
--- a/go.sum
+++ b/go.sum
@@ -102,8 +102,6 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx
 github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
 github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM=
 github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
-github.com/mattn/goveralls v0.0.4 h1:/mdWfiU2y8kZ48EtgByYev/XT3W4dkTuKLOJJsh/r+o=
-github.com/mattn/goveralls v0.0.4/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
 github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
 github.com/nkovacs/streamquote v0.0.0-20170412213628-49af9bddb229/go.mod h1:0aYXnNPJ8l7uZxf45rWW1a/uME32OF0rhiYGNQ2oF2E=
 github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -129,7 +127,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU=
 golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g=
 golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -150,7 +147,6 @@ golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU
 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
 golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
 golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
-golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -209,11 +205,8 @@ golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2 h1:EtTFh6h4SAKemS+CURDMTDIANuduG5zKEXShyy18bGA=
 golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea h1:mtRJM/ln5qwEigajtnZtuARALEPOooGf5lwkM5a9tt4=
-golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
diff --git a/read.go b/read.go
index dc59f9f..517022b 100644
--- a/read.go
+++ b/read.go
@@ -79,7 +79,7 @@ func newMsgReader(c *Conn) *msgReader {
 	}
 
 	mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768)
-	if c.deflate() && mr.contextTakeover() {
+	if c.flate() && mr.flateContextTakeover() {
 		mr.initFlateReader()
 	}
 
@@ -87,30 +87,27 @@ func newMsgReader(c *Conn) *msgReader {
 }
 
 func (mr *msgReader) initFlateReader() {
-	mr.deflateReader = getFlateReader(readerFunc(mr.read))
-	mr.limitReader.r = mr.deflateReader
+	mr.flateReader = getFlateReader(readerFunc(mr.read))
+	mr.limitReader.r = mr.flateReader
 }
 
 func (mr *msgReader) close() {
 	mr.c.readMu.Lock(context.Background())
 	defer mr.c.readMu.Unlock()
 
-	if mr.deflateReader != nil {
-		putFlateReader(mr.deflateReader)
-		mr.deflateReader = nil
-	}
+	mr.returnFlateReader()
 }
 
-func (mr *msgReader) contextTakeover() bool {
+func (mr *msgReader) flateContextTakeover() bool {
 	if mr.c.client {
-		return mr.c.copts.serverNoContextTakeover
+		return !mr.c.copts.serverNoContextTakeover
 	}
-	return mr.c.copts.clientNoContextTakeover
+	return !mr.c.copts.clientNoContextTakeover
 }
 
 func (c *Conn) readRSV1Illegal(h header) bool {
 	// If compression is enabled, rsv1 is always illegal.
-	if !c.deflate() {
+	if !c.flate() {
 		return true
 	}
 	// rsv1 is only allowed on data frames beginning messages.
@@ -269,6 +266,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
 	err = fmt.Errorf("received close frame: %w", ce)
 	c.setCloseErr(err)
 	c.writeClose(ce.Code, ce.Reason)
+	c.close(err)
 	return err
 }
 
@@ -304,11 +302,11 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
 type msgReader struct {
 	c *Conn
 
-	ctx           context.Context
-	deflate       bool
-	deflateReader io.Reader
-	deflateTail   strings.Reader
-	limitReader   *limitReader
+	ctx         context.Context
+	deflate     bool
+	flateReader io.Reader
+	deflateTail strings.Reader
+	limitReader *limitReader
 
 	fin           bool
 	payloadLength int64
@@ -319,7 +317,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) {
 	mr.ctx = ctx
 	mr.deflate = h.rsv1
 	if mr.deflate {
-		if !mr.contextTakeover() {
+		if !mr.flateContextTakeover() {
 			mr.initFlateReader()
 		}
 		mr.deflateTail.Reset(deflateMessageTail)
@@ -335,8 +333,19 @@ func (mr *msgReader) setFrame(h header) {
 	mr.maskKey = h.maskKey
 }
 
-func (mr *msgReader) Read(p []byte) (_ int, err error) {
+func (mr *msgReader) Read(p []byte) (n int, err error) {
 	defer func() {
+		r := recover()
+		if r != nil {
+			if r != "ANMOL" {
+				panic(r)
+			}
+			err = io.EOF
+			if !mr.flateContextTakeover() {
+				mr.returnFlateReader()
+			}
+		}
+
 		errd.Wrap(&err, "failed to read")
 		if errors.Is(err, io.EOF) {
 			err = io.EOF
@@ -349,24 +358,27 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) {
 	}
 	defer mr.c.readMu.Unlock()
 
-	if mr.payloadLength == 0 && mr.fin {
-		if mr.c.deflate() && !mr.contextTakeover() {
-			if mr.deflateReader != nil {
-				putFlateReader(mr.deflateReader)
-				mr.deflateReader = nil
-			}
-		}
-		return 0, io.EOF
-	}
-
 	return mr.limitReader.Read(p)
 }
 
+func (mr *msgReader) returnFlateReader() {
+	if mr.flateReader != nil {
+		putFlateReader(mr.flateReader)
+		mr.flateReader = nil
+	}
+}
+
 func (mr *msgReader) read(p []byte) (int, error) {
 	if mr.payloadLength == 0 {
-		if mr.fin && mr.deflate {
-			n, _ := mr.deflateTail.Read(p)
-			return n, nil
+		if mr.fin {
+			if mr.deflate {
+				if mr.deflateTail.Len() == 0 {
+					panic("ANMOL")
+				}
+				n, _ := mr.deflateTail.Read(p)
+				return n, nil
+			}
+			return 0, io.EOF
 		}
 
 		h, err := mr.c.readLoop(mr.ctx)
diff --git a/write.go b/write.go
index 526b3b6..de20e04 100644
--- a/write.go
+++ b/write.go
@@ -55,7 +55,7 @@ func newMsgWriter(c *Conn) *msgWriter {
 	mw.trimWriter = &trimLastFourBytesWriter{
 		w: writerFunc(mw.write),
 	}
-	if c.deflate() && mw.deflateContextTakeover() {
+	if c.flate() && mw.flateContextTakeover() {
 		mw.ensureFlateWriter()
 	}
 
@@ -63,14 +63,16 @@ func newMsgWriter(c *Conn) *msgWriter {
 }
 
 func (mw *msgWriter) ensureFlateWriter() {
-	mw.flateWriter = getFlateWriter(mw.trimWriter)
+	if mw.flateWriter == nil {
+		mw.flateWriter = getFlateWriter(mw.trimWriter)
+	}
 }
 
-func (mw *msgWriter) deflateContextTakeover() bool {
+func (mw *msgWriter) flateContextTakeover() bool {
 	if mw.c.client {
-		return mw.c.copts.clientNoContextTakeover
+		return !mw.c.copts.clientNoContextTakeover
 	}
-	return mw.c.copts.serverNoContextTakeover
+	return !mw.c.copts.serverNoContextTakeover
 }
 
 func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
@@ -87,7 +89,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
 		return 0, err
 	}
 
-	if !c.deflate() {
+	if !c.flate() {
 		// Fast single frame path.
 		defer c.msgWriter.mu.Unlock()
 		return c.writeFrame(ctx, true, c.msgWriter.opcode, p)
@@ -107,11 +109,12 @@ type msgWriter struct {
 
 	mu *mu
 
-	deflate bool
-	ctx     context.Context
-	opcode  opcode
-	closed  bool
+	ctx    context.Context
+	opcode opcode
+	closed bool
 
+	// TODO pass down into writeFrame
+	flate       bool
 	trimWriter  *trimLastFourBytesWriter
 	flateWriter *flate.Writer
 }
@@ -125,7 +128,7 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
 	mw.closed = false
 	mw.ctx = ctx
 	mw.opcode = opcode(typ)
-	mw.deflate = false
+	mw.flate = false
 	return nil
 }
 
@@ -137,13 +140,14 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 		return 0, errors.New("cannot use closed writer")
 	}
 
-	if mw.c.deflate() {
-		if !mw.deflate {
-			if !mw.deflateContextTakeover() {
+	if mw.c.flate() {
+		if !mw.flate {
+			mw.flate = true
+
+			if !mw.flateContextTakeover() {
 				mw.ensureFlateWriter()
 			}
 			mw.trimWriter.reset()
-			mw.deflate = true
 		}
 
 		return mw.flateWriter.Write(p)
@@ -170,7 +174,7 @@ func (mw *msgWriter) Close() (err error) {
 	}
 	mw.closed = true
 
-	if mw.c.deflate() {
+	if mw.flate {
 		err = mw.flateWriter.Flush()
 		if err != nil {
 			return fmt.Errorf("failed to flush flate writer: %w", err)
@@ -182,9 +186,9 @@ func (mw *msgWriter) Close() (err error) {
 		return fmt.Errorf("failed to write fin frame: %w", err)
 	}
 
-	if mw.deflate && !mw.deflateContextTakeover() {
+	if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil {
 		putFlateWriter(mw.flateWriter)
-		mw.deflate = false
+		mw.flateWriter = nil
 	}
 
 	mw.mu.Unlock()
@@ -192,9 +196,10 @@ func (mw *msgWriter) Close() (err error) {
 }
 
 func (mw *msgWriter) close() {
-	if mw.c.deflate() && mw.deflateContextTakeover() {
+	if mw.flateWriter != nil && mw.flateContextTakeover() {
 		mw.mu.Lock(context.Background())
 		putFlateWriter(mw.flateWriter)
+		mw.flateWriter = nil
 	}
 }
 
@@ -236,7 +241,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
 	}
 
 	c.writeHeader.rsv1 = false
-	if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) {
+	if c.flate() && (opcode == opText || opcode == opBinary) {
 		c.writeHeader.rsv1 = true
 	}
 
-- 
GitLab