From 6f6fa430a6e88699b3b8aef5d1b8499100f3e8b9 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 30 Dec 2019 22:03:21 -0500 Subject: [PATCH] Refactor autobahn --- accept.go | 2 - autobahn_test.go | 319 ++++++++++++++++++++++++++--------------------- close.go | 2 + compress.go | 18 +-- conn.go | 7 +- go.mod | 2 - go.sum | 7 -- read.go | 74 ++++++----- write.go | 45 ++++--- 9 files changed, 260 insertions(+), 216 deletions(-) diff --git a/accept.go b/accept.go index ea7beeb..f16180f 100644 --- a/accept.go +++ b/accept.go @@ -37,8 +37,6 @@ type AcceptOptions struct { // If used incorrectly your WebSocket server will be open to CSRF attacks. InsecureSkipVerify bool - // CompressionMode sets the compression mode. - // See the docs on CompressionMode. CompressionMode CompressionMode } diff --git a/autobahn_test.go b/autobahn_test.go index 6b3b5b7..16384b2 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net" "net/http" - "net/http/httptest" "os" "os/exec" "strconv" @@ -17,9 +16,27 @@ import ( "testing" "time" + "cdr.dev/slog/sloggers/slogtest/assert" + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" ) +var excludedAutobahnCases = []string{ + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. + "6.*", "7.5.1", + + // We skip the tests related to requestMaxWindowBits as that is unimplemented due + // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 + "13.3.*", "13.4.*", "13.5.*", "13.6.*", + + "12.*", + "13.*", +} + +var autobahnCases = []string{"*"} + // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahn(t *testing.T) { t.Parallel() @@ -35,19 +52,17 @@ func TestAutobahn(t *testing.T) { func testServerAutobahn(t *testing.T) { t.Parallel() - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) - if err != nil { - t.Logf("server handshake failed: %+v", err) - return - } - echoLoop(r.Context(), c) - })) - defer s.Close() + assert.Success(t, "accept", err) + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() - spec := map[string]interface{}{ + specFile, err := tempJSONFile(map[string]interface{}{ "outdir": "ci/out/wstestServerReports", "servers": []interface{}{ map[string]interface{}{ @@ -55,92 +70,105 @@ func testServerAutobahn(t *testing.T) { "url": strings.Replace(s.URL, "http", "ws", 1), }, }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - "exclude-cases": []string{"6.*", "7.5.1"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } + "cases": autobahnCases, + "exclude-cases": excludedAutobahnCases, + }) + assert.Success(t, "tempJSONFile", err) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) defer cancel() - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + args := []string{"--mode", "fuzzingclient", "--spec", specFile} wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() - if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) - } + _, err = wstest.CombinedOutput() + assert.Success(t, "wstest", err) checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err - } - l.Close() - return l.Addr().String(), nil -} - func testClientAutobahn(t *testing.T) { t.Parallel() - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) - } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() - wsServerURL := "ws://" + serverAddr + wstestURL, closeFn, err := wstestClientServer(ctx) + assert.Success(t, "wstestClient", err) + defer closeFn() - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) + err = waitWS(ctx, wstestURL) + assert.Success(t, "waitWS", err) + + cases, err := wstestCaseCount(ctx, wstestURL) + assert.Success(t, "wstestCaseCount", err) + + t.Run("cases", func(t *testing.T) { + for i := 1; i <= cases; i++ { + i := i + t.Run("", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(ctx, time.Second*45) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) + assert.Success(t, "autobahn dial", err) + + err = echoLoop(ctx, c) + t.Logf("echoLoop: %+v", err) + }) + } + }) + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) + assert.Success(t, "dial", err) + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func waitWS(ctx context.Context, url string) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + for ctx.Err() == nil { + c, _, err := websocket.Dial(ctx, url, nil) + if err != nil { + continue + } + c.Close(websocket.StatusNormalClosure, "") + return nil } - defer specFile.Close() - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) + return ctx.Err() +} + +func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { + serverAddr, err := unusedListenAddr() if err != nil { - t.Fatalf("failed to write spec: %v", err) + return "", nil, err } - err = specFile.Close() + url = "ws://" + serverAddr + + specFile, err := tempJSONFile(map[string]interface{}{ + "url": url, + "outdir": "ci/out/wstestClientReports", + "cases": autobahnCases, + "exclude-cases": excludedAutobahnCases, + }) if err != nil { - t.Fatalf("failed to close file: %v", err) + return "", nil, fmt.Errorf("failed to write spec: %w", err) } - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer func() { + if err != nil { + cancel() + } + }() - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), + args := []string{"--mode", "fuzzingserver", "--spec", specFile, // Disables some server that runs as part of fuzzingserver mode. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", @@ -148,101 +176,104 @@ func testClientAutobahn(t *testing.T) { wstest := exec.CommandContext(ctx, "wstest", args...) err = wstest.Start() if err != nil { - t.Fatal(err) + return "", nil, fmt.Errorf("failed to start wstest: %w", err) } - defer func() { - err := wstest.Process.Kill() - if err != nil { - t.Error(err) - } - }() - - // Let it come up. - time.Sleep(time.Second * 5) - - var cases int - func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) - } - cases, err = strconv.Atoi(string(b)) - if err != nil { - t.Fatal(err) - } - c.Close(websocket.StatusNormalClosure, "") - }() + return url, func() { + wstest.Process.Kill() + }, nil +} - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) - defer cancel() +func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { + defer errd.Wrap(&err, "failed to get case count") - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } - echoLoop(ctx, c) - }() + c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) + if err != nil { + return 0, err } + defer c.Close(websocket.StatusInternalError, "") - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) + _, r, err := c.Reader(ctx) + if err != nil { + return 0, err + } + b, err := ioutil.ReadAll(r) + if err != nil { + return 0, err + } + cases, err = strconv.Atoi(string(b)) if err != nil { - t.Fatal(err) + return 0, err } + c.Close(websocket.StatusNormalClosure, "") - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") + return cases, nil } func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } + assert.Success(t, "ioutil.ReadFile", err) var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } + assert.Success(t, "json.Unmarshal", err) - var failed bool for _, tests := range indexJSON { for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) - } + t.Run(test, func(t *testing.T) { + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + t.Errorf("bad close behaviour") + } + + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + t.Errorf("failed") + } + }) } } - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) - } + if t.Failed() { + htmlPath := strings.Replace(path, ".json", ".html", 1) + t.Errorf("detected autobahn violation, see %q", htmlPath) } } + +func unusedListenAddr() (_ string, err error) { + defer errd.Wrap(&err, "failed to get unused listen address") + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +func tempJSONFile(v interface{}) (string, error) { + f, err := ioutil.TempFile("", "temp.json") + if err != nil { + return "", fmt.Errorf("temp file: %w", err) + } + defer f.Close() + + e := json.NewEncoder(f) + e.SetIndent("", "\t") + err = e.Encode(v) + if err != nil { + return "", fmt.Errorf("json encode: %w", err) + } + + err = f.Close() + if err != nil { + return "", fmt.Errorf("close temp file: %w", err) + } + + return f.Name(), nil +} diff --git a/close.go b/close.go index 7ccdb17..c5c51c6 100644 --- a/close.go +++ b/close.go @@ -147,6 +147,8 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { } func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() diff --git a/compress.go b/compress.go index 2410cb4..8c4dbe2 100644 --- a/compress.go +++ b/compress.go @@ -9,6 +9,14 @@ import ( "sync" ) +type CompressionOptions struct { + // Mode controls the compression mode. + Mode CompressionMode + + // Threshold controls the minimum size of a message before compression is applied. + Threshold int +} + // CompressionMode controls the modes available RFC 7692's deflate extension. // See https://tools.ietf.org/html/rfc7692 // @@ -29,14 +37,8 @@ const ( // The message will only be compressed if greater than 512 bytes. CompressionNoContextTakeover CompressionMode = iota - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. - // As most WebSocket protocols are repetitive, this can be very efficient. - // - // The message will only be compressed if greater than 128 bytes. - // - // If the peer negotiates NoContextTakeover on the client or server side, it will be - // used instead as this is required by the RFC. + // Unimplemented for now due to limitations in compress/flate. + // See https://github.com/golang/go/issues/31514#issuecomment-569668619 CompressionContextTakeover // CompressionDisabled disables the deflate extension. diff --git a/conn.go b/conn.go index 061c451..5ccf9f9 100644 --- a/conn.go +++ b/conn.go @@ -176,7 +176,7 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) deflate() bool { +func (c *Conn) flate() bool { return c.copts != nil } @@ -262,5 +262,8 @@ func (m *mu) TryLock() bool { } func (m *mu) Unlock() { - <-m.ch + select { + case <-m.ch: + default: + } } diff --git a/go.mod b/go.mod index 0609848..01ec18f 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,5 @@ require ( github.com/gobwas/ws v1.0.2 github.com/golang/protobuf v1.3.2 github.com/gorilla/websocket v1.4.1 - github.com/mattn/goveralls v0.0.4 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea // indirect ) diff --git a/go.sum b/go.sum index df11eba..864efaa 100644 --- a/go.sum +++ b/go.sum @@ -102,8 +102,6 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/mattn/goveralls v0.0.4 h1:/mdWfiU2y8kZ48EtgByYev/XT3W4dkTuKLOJJsh/r+o= -github.com/mattn/goveralls v0.0.4/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/nkovacs/streamquote v0.0.0-20170412213628-49af9bddb229/go.mod h1:0aYXnNPJ8l7uZxf45rWW1a/uME32OF0rhiYGNQ2oF2E= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -129,7 +127,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -150,7 +147,6 @@ golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -209,11 +205,8 @@ golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2 h1:EtTFh6h4SAKemS+CURDMTDIANuduG5zKEXShyy18bGA= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea h1:mtRJM/ln5qwEigajtnZtuARALEPOooGf5lwkM5a9tt4= -golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/read.go b/read.go index dc59f9f..517022b 100644 --- a/read.go +++ b/read.go @@ -79,7 +79,7 @@ func newMsgReader(c *Conn) *msgReader { } mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) - if c.deflate() && mr.contextTakeover() { + if c.flate() && mr.flateContextTakeover() { mr.initFlateReader() } @@ -87,30 +87,27 @@ func newMsgReader(c *Conn) *msgReader { } func (mr *msgReader) initFlateReader() { - mr.deflateReader = getFlateReader(readerFunc(mr.read)) - mr.limitReader.r = mr.deflateReader + mr.flateReader = getFlateReader(readerFunc(mr.read)) + mr.limitReader.r = mr.flateReader } func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) defer mr.c.readMu.Unlock() - if mr.deflateReader != nil { - putFlateReader(mr.deflateReader) - mr.deflateReader = nil - } + mr.returnFlateReader() } -func (mr *msgReader) contextTakeover() bool { +func (mr *msgReader) flateContextTakeover() bool { if mr.c.client { - return mr.c.copts.serverNoContextTakeover + return !mr.c.copts.serverNoContextTakeover } - return mr.c.copts.clientNoContextTakeover + return !mr.c.copts.clientNoContextTakeover } func (c *Conn) readRSV1Illegal(h header) bool { // If compression is enabled, rsv1 is always illegal. - if !c.deflate() { + if !c.flate() { return true } // rsv1 is only allowed on data frames beginning messages. @@ -269,6 +266,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) + c.close(err) return err } @@ -304,11 +302,11 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro type msgReader struct { c *Conn - ctx context.Context - deflate bool - deflateReader io.Reader - deflateTail strings.Reader - limitReader *limitReader + ctx context.Context + deflate bool + flateReader io.Reader + deflateTail strings.Reader + limitReader *limitReader fin bool payloadLength int64 @@ -319,7 +317,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.deflate = h.rsv1 if mr.deflate { - if !mr.contextTakeover() { + if !mr.flateContextTakeover() { mr.initFlateReader() } mr.deflateTail.Reset(deflateMessageTail) @@ -335,8 +333,19 @@ func (mr *msgReader) setFrame(h header) { mr.maskKey = h.maskKey } -func (mr *msgReader) Read(p []byte) (_ int, err error) { +func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { + r := recover() + if r != nil { + if r != "ANMOL" { + panic(r) + } + err = io.EOF + if !mr.flateContextTakeover() { + mr.returnFlateReader() + } + } + errd.Wrap(&err, "failed to read") if errors.Is(err, io.EOF) { err = io.EOF @@ -349,24 +358,27 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { } defer mr.c.readMu.Unlock() - if mr.payloadLength == 0 && mr.fin { - if mr.c.deflate() && !mr.contextTakeover() { - if mr.deflateReader != nil { - putFlateReader(mr.deflateReader) - mr.deflateReader = nil - } - } - return 0, io.EOF - } - return mr.limitReader.Read(p) } +func (mr *msgReader) returnFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { - if mr.fin && mr.deflate { - n, _ := mr.deflateTail.Read(p) - return n, nil + if mr.fin { + if mr.deflate { + if mr.deflateTail.Len() == 0 { + panic("ANMOL") + } + n, _ := mr.deflateTail.Read(p) + return n, nil + } + return 0, io.EOF } h, err := mr.c.readLoop(mr.ctx) diff --git a/write.go b/write.go index 526b3b6..de20e04 100644 --- a/write.go +++ b/write.go @@ -55,7 +55,7 @@ func newMsgWriter(c *Conn) *msgWriter { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } - if c.deflate() && mw.deflateContextTakeover() { + if c.flate() && mw.flateContextTakeover() { mw.ensureFlateWriter() } @@ -63,14 +63,16 @@ func newMsgWriter(c *Conn) *msgWriter { } func (mw *msgWriter) ensureFlateWriter() { - mw.flateWriter = getFlateWriter(mw.trimWriter) + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } } -func (mw *msgWriter) deflateContextTakeover() bool { +func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { - return mw.c.copts.clientNoContextTakeover + return !mw.c.copts.clientNoContextTakeover } - return mw.c.copts.serverNoContextTakeover + return !mw.c.copts.serverNoContextTakeover } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { @@ -87,7 +89,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return 0, err } - if !c.deflate() { + if !c.flate() { // Fast single frame path. defer c.msgWriter.mu.Unlock() return c.writeFrame(ctx, true, c.msgWriter.opcode, p) @@ -107,11 +109,12 @@ type msgWriter struct { mu *mu - deflate bool - ctx context.Context - opcode opcode - closed bool + ctx context.Context + opcode opcode + closed bool + // TODO pass down into writeFrame + flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer } @@ -125,7 +128,7 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { mw.closed = false mw.ctx = ctx mw.opcode = opcode(typ) - mw.deflate = false + mw.flate = false return nil } @@ -137,13 +140,14 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, errors.New("cannot use closed writer") } - if mw.c.deflate() { - if !mw.deflate { - if !mw.deflateContextTakeover() { + if mw.c.flate() { + if !mw.flate { + mw.flate = true + + if !mw.flateContextTakeover() { mw.ensureFlateWriter() } mw.trimWriter.reset() - mw.deflate = true } return mw.flateWriter.Write(p) @@ -170,7 +174,7 @@ func (mw *msgWriter) Close() (err error) { } mw.closed = true - if mw.c.deflate() { + if mw.flate { err = mw.flateWriter.Flush() if err != nil { return fmt.Errorf("failed to flush flate writer: %w", err) @@ -182,9 +186,9 @@ func (mw *msgWriter) Close() (err error) { return fmt.Errorf("failed to write fin frame: %w", err) } - if mw.deflate && !mw.deflateContextTakeover() { + if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil { putFlateWriter(mw.flateWriter) - mw.deflate = false + mw.flateWriter = nil } mw.mu.Unlock() @@ -192,9 +196,10 @@ func (mw *msgWriter) Close() (err error) { } func (mw *msgWriter) close() { - if mw.c.deflate() && mw.deflateContextTakeover() { + if mw.flateWriter != nil && mw.flateContextTakeover() { mw.mu.Lock(context.Background()) putFlateWriter(mw.flateWriter) + mw.flateWriter = nil } } @@ -236,7 +241,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } c.writeHeader.rsv1 = false - if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) { + if c.flate() && (opcode == opText || opcode == opBinary) { c.writeHeader.rsv1 = true } -- GitLab