diff --git a/.circleci/config.yml b/.circleci/config.yml index 65b17aa06375e372e0038b394cea284d5c8c8a4c..196ec6714366370a6bb462d114cb702f2f6770bd 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2 jobs: fmt: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: @@ -19,7 +19,7 @@ jobs: lint: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: @@ -36,7 +36,7 @@ jobs: test: docker: - - image: nhooyr/websocket-ci + - image: nhooyr/websocket-ci@sha256:371ca985ce2548840aeb0f8434a551708cdfe0628be722c361958e65cdded945 steps: - checkout - restore_cache: diff --git a/export_test.go b/export_test.go index 9c65360aae79d27de3001502438b91eb059accbc..fc885bffd092430565252225a4e07423333025b6 100644 --- a/export_test.go +++ b/export_test.go @@ -6,15 +6,22 @@ import ( type Addr = websocketAddr -type Header = header - const OPClose = opClose +const OPBinary = opBinary const OPPing = opPing +const OPContinuation = opContinuation func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { return c.writeFrame(ctx, fin, opcode, p) } +func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { + return c.realWriteFrame(ctx, header{ + opcode: opBinary, + payloadLength: 5, + }, make([]byte, 10)) +} + func (c *Conn) Flush() error { return c.bw.Flush() } diff --git a/websocket_test.go b/websocket_test.go index 73020f5ee430f973975df253f2135dfea0d661a1..1963ce7098460b7015146d8ea5aa14f54dd0927e 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -146,6 +146,9 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected error regarding bad origin") } + if !strings.Contains(err.Error(), "not authorized") { + return xerrors.Errorf("expected error regarding bad origin: %+v", err) + } return nil }, client: func(ctx context.Context, u string) error { @@ -158,6 +161,9 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusInternalError, "") return xerrors.New("expected handshake failure") } + if !strings.Contains(err.Error(), "403") { + return xerrors.Errorf("expected handshake failure: %+v", err) + } return nil }, }, @@ -390,8 +396,8 @@ func TestConn(t *testing.T) { nc.SetDeadline(time.Now().Add(time.Second * 15)) _, err := nc.Read(make([]byte, 1)) - if err == nil { - return xerrors.Errorf("expected error") + if err == nil || !strings.Contains(err.Error(), "unexpected frame type read") { + return xerrors.Errorf("expected error: %+v", err) } return nil @@ -426,7 +432,7 @@ func TestConn(t *testing.T) { } _, err = nc.Write([]byte{0xff}) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "websocket closed") { return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err) } @@ -586,8 +592,8 @@ func TestConn(t *testing.T) { name: "readLimit", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected error but got nil") + if err == nil || !strings.Contains(err.Error(), "read limited at") { + return xerrors.Errorf("expected error but got nil: %+v", err) } return nil }, @@ -614,7 +620,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v interface{} err := wsjson.Read(ctx, c, &v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -628,7 +634,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v interface{} err := wsjson.Read(ctx, c, &v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "failed to unmarshal json") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -641,7 +647,7 @@ func TestConn(t *testing.T) { name: "wsjson/badWrite", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -659,7 +665,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v proto.Message err := wspb.Read(ctx, c, v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "unexpected frame type") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -673,7 +679,7 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { var v timestamp.Timestamp err := wspb.Read(ctx, c, &v) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "failed to unmarshal protobuf") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -686,24 +692,7 @@ func TestConn(t *testing.T) { name: "wspb/badWrite", server: func(ctx context.Context, c *websocket.Conn) error { _, _, err := c.Read(ctx) - if err == nil { - return xerrors.Errorf("expected error: %v", err) - } - return nil - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wspb.Write(ctx, c, nil) - if err == nil { - return xerrors.Errorf("expected error: %v", err) - } - return nil - }, - }, - { - name: "wspb/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "StatusInternalError") { return xerrors.Errorf("expected error: %v", err) } return nil @@ -736,13 +725,13 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() err := c.Ping(ctx) - if err == nil { + if err == nil || !xerrors.Is(err, context.DeadlineExceeded) { return xerrors.Errorf("expected nil error: %+v", err) } return nil }, client: func(ctx context.Context, c *websocket.Conn) error { - time.Sleep(time.Second) + c.Read(ctx) return nil }, }, @@ -769,19 +758,14 @@ func TestConn(t *testing.T) { server: func(ctx context.Context, c *websocket.Conn) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - <-ctx.Done() - _, err = r.Read(make([]byte, 1)) + _, _, err := c.Read(ctx) if !xerrors.Is(err, context.DeadlineExceeded) { return xerrors.Errorf("expected deadline exceeded error: %+v", err) } return nil }, client: func(ctx context.Context, c *websocket.Conn) error { - time.Sleep(time.Second) + c.Read(ctx) return nil }, }, @@ -912,7 +896,7 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) - if err == nil { + if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { return xerrors.Errorf("expected non nil error: %v", err) } return nil @@ -942,11 +926,57 @@ func TestConn(t *testing.T) { return err } _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "previous message not read to completion") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) + if err != nil { + return err + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, err = w.Write([]byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + err = c.Flush() + if err != nil { + return xerrors.Errorf("failed to flush: %w", err) + } + _, _, err = c.Read(ctx) if err == nil { return xerrors.Errorf("expected non nil error: %v", err) } return nil }, + }, + { + name: "newMessageInFragmentedMessage", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + p := make([]byte, 10) + _, err = io.ReadFull(r, p) + if err != nil { + return err + } + _, _, err = c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, client: func(ctx context.Context, c *websocket.Conn) error { w, err := c.Writer(ctx, websocket.MessageBinary) if err != nil { @@ -960,6 +990,83 @@ func TestConn(t *testing.T) { if err != nil { return xerrors.Errorf("failed to flush: %w", err) } + _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + _, _, err = c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + }, + { + name: "continuationFrameWithoutDataFrame", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Reader(ctx) + if err == nil || !strings.Contains(err.Error(), "received continuation frame not after data") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, false, websocket.OPContinuation, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } + return nil + }, + }, + { + name: "readBeforeEOF", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + var v interface{} + d := json.NewDecoder(r) + err = d.Decode(&v) + if err != nil { + return err + } + _, b, err := c.Read(ctx) + if err != nil { + return err + } + if string(b) != "hi" { + return xerrors.Errorf("expected hi but got %q", string(b)) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, "hi") + if err != nil { + return err + } + return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + }, + }, + { + name: "newMessageInFragmentedMessage2", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + p := make([]byte, 11) + _, err = io.ReadFull(r, p) + if err == nil || !strings.Contains(err.Error(), "received new data message without finishing") { + return xerrors.Errorf("expected non nil error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + w, err := c.Writer(ctx, websocket.MessageBinary) + if err != nil { + return err + } _, err = w.Write([]byte(strings.Repeat("x", 10))) if err != nil { return xerrors.Errorf("expected non nil error") @@ -968,6 +1075,10 @@ func TestConn(t *testing.T) { if err != nil { return xerrors.Errorf("failed to flush: %w", err) } + _, err = c.WriteFrame(ctx, true, websocket.OPBinary, []byte(strings.Repeat("x", 10))) + if err != nil { + return xerrors.Errorf("expected non nil error") + } _, _, err = c.Read(ctx) if err == nil { return xerrors.Errorf("expected non nil error: %v", err) @@ -975,6 +1086,41 @@ func TestConn(t *testing.T) { return nil }, }, + { + name: "doubleRead", + server: func(ctx context.Context, c *websocket.Conn) error { + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + _, err = ioutil.ReadAll(r) + if err != nil { + return err + } + _, err = r.Read(make([]byte, 1)) + if err == nil || !strings.Contains(err.Error(), "cannot use EOFed reader") { + return xerrors.Errorf("expected non nil error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageBinary, []byte("hi")) + }, + }, + { + name: "eofInPayload", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "failed to read frame payload") { + return xerrors.Errorf("expected failed to read frame payload: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteHalfFrame(ctx) + return err + }, + }, } for _, tc := range testCases { tc := tc @@ -990,8 +1136,7 @@ func TestConn(t *testing.T) { return err } defer c.Close(websocket.StatusInternalError, "") - tc.server(r.Context(), c) - return nil + return tc.server(r.Context(), c) }, tls) defer closeFn()