diff --git a/websocket_test.go b/websocket_test.go index 993ff9ab8a469e3dd708c222f7d1849d097e8c26..2d7db271087296ebb564f148bf11c3a5f588c1f9 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -74,6 +74,50 @@ func TestHandshake(t *testing.T) { return nil }, }, + { + name: "closeError", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = wsjson.Write(r.Context(), c, "hello") + if err != nil { + return err + } + + return nil + }, + client: func(ctx context.Context, u string) error { + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Subprotocols: []string{"meow"}, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + var m string + err = wsjson.Read(ctx, c, &m) + if err != nil { + return err + } + + if m != "hello" { + return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) + } + + _, _, err = c.Reader(ctx) + var cerr websocket.CloseError + if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("unexpected error: %+v", err) + } + + return nil + }, + }, { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error {