diff --git a/contrib/codecs/websocket/client.go b/contrib/codecs/websocket/client.go index 86b7f35761b4d278f10e66a5f722783f4b4b5e1e..03fcaca9f7f81703dcb1e6596d2ec7f6033cf0ad 100644 --- a/contrib/codecs/websocket/client.go +++ b/contrib/codecs/websocket/client.go @@ -1,54 +1,28 @@ package websocket import ( + "gfx.cafe/open/jrpc/contrib/codecs/rdwr" + "context" - jrpc2 "gfx.cafe/open/jrpc/pkg/codec" - "sync" - "gfx.cafe/open/jrpc" "nhooyr.io/websocket" ) type Client struct { - conn *websocket.Conn - reconnectFunc reconnectFunc - - mu sync.RWMutex + *rdwr.Client + conn *websocket.Conn } -type reconnectFunc func(ctx context.Context) (*websocket.Conn, error) - -func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { - conn, err := connect(initctx) - if err != nil { - return nil, err +func newClient(conn *websocket.Conn) (*Client, error) { + conn.SetReadLimit(WsMessageSizeLimit) + netConn := websocket.NetConn(context.Background(), conn, websocket.MessageText) + c := &Client{ + Client: rdwr.NewClient(netConn, netConn, nil), + conn: conn, } - c := &Client{} - c.conn = conn - c.reconnectFunc = connect return c, nil } -func (c *Client) Do(ctx context.Context, result any, method string, params any) error { - panic("not implemented") // TODO: Implement -} - -func (c *Client) BatchCall(ctx context.Context, b ...jrpc2.BatchElem) error { - panic("not implemented") // TODO: Implement -} - -func (c *Client) SetHeader(key string, value string) { - panic("not implemented") // TODO: Implement -} - func (c *Client) Close() error { - panic("not implemented") // TODO: Implement -} - -func (c *Client) Notify(ctx context.Context, method string, args ...any) error { - panic("not implemented") // TODO: Implement -} - -func (c *Client) Subscribe(ctx context.Context, namespace string, channel any, args ...any) (*jrpc.ClientSubscription, error) { - panic("not implemented") // TODO: Implement + return c.conn.Close(websocket.StatusNormalClosure, "") } diff --git a/contrib/codecs/websocket/client_example_test.go b/contrib/codecs/websocket/client_example_test.go index 09edfa51c793d05c5367f7eeacb334470c238447..27994a80f24f2b3d47690d6849cebf9b8a52f91c 100644 --- a/contrib/codecs/websocket/client_example_test.go +++ b/contrib/codecs/websocket/client_example_test.go @@ -1,13 +1,8 @@ package websocket_test -import ( - "context" - "fmt" - "time" +/* - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/rpc" -) +Re enable this test when subscriptions // In this example, our client wishes to track the latest 'block number' // known to the server. The server supports two methods: @@ -71,3 +66,5 @@ func subscribeBlocks(client *rpc.Client, subch chan Block) { // the connection. fmt.Println("connection lost: ", <-sub.Err()) } + +*/ diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go index 708bc8cb5dd4392e386ecaefe121a43791aa892b..4be78b854c7bad129e54cca05caf454d43ae07e6 100644 --- a/contrib/codecs/websocket/codec.go +++ b/contrib/codecs/websocket/codec.go @@ -1 +1,78 @@ package websocket + +import ( + "context" + "net/http" + "time" + + "nhooyr.io/websocket" + + "gfx.cafe/open/jrpc/contrib/codecs/rdwr" + "gfx.cafe/open/jrpc/pkg/codec" +) + +type Codec struct { + *rdwr.Codec + conn *websocket.Conn + + i codec.PeerInfo +} + +func newWebsocketCodec(ctx context.Context, conn *websocket.Conn, host string, req http.Header) *Codec { + conn.SetReadLimit(WsMessageSizeLimit) + netConn := websocket.NetConn(ctx, conn, websocket.MessageText) + c := &Codec{ + Codec: rdwr.NewCodec(netConn, netConn, nil), + conn: conn, + } + c.i.Transport = "ws" + // Fill in connection details. + c.i.HTTP.Host = host + // traefik proxy protocol headers + c.i.HTTP.Origin = req.Get("X-Real-Ip") + if c.i.HTTP.Origin == "" { + c.i.HTTP.Origin = req.Get("X-Forwarded-For") + } + // origin header fallback + if c.i.HTTP.Origin == "" { + c.i.HTTP.Origin = req.Get("origin") + } + c.i.RemoteAddr = c.i.HTTP.Origin + c.i.HTTP.UserAgent = req.Get("User-Agent") + c.i.HTTP.Headers = req + // Start pinger. + go heartbeat(ctx, conn, WsPingInterval) + return c +} + +func heartbeat(ctx context.Context, c *websocket.Conn, d time.Duration) { + t := time.NewTimer(d) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + } + err := c.Ping(ctx) + if err != nil { + return + } + t.Reset(time.Minute) + } +} + +func (c *Codec) PeerInfo() codec.PeerInfo { + return c.i +} + +func (c *Codec) Close() error { + if err := c.Codec.Close(); err != nil { + return err + } + return c.conn.Close(websocket.StatusNormalClosure, "") +} + +func (c *Codec) RemoteAddr() string { + return c.i.RemoteAddr +} diff --git a/contrib/codecs/websocket/codec_test.go b/contrib/codecs/websocket/codec_test.go new file mode 100644 index 0000000000000000000000000000000000000000..dbb3d6def53e7eccc4d87ba74cbc371801106c76 --- /dev/null +++ b/contrib/codecs/websocket/codec_test.go @@ -0,0 +1,25 @@ +package websocket + +import ( + "context" + "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/open/jrpc/pkg/jrpctest" + "gfx.cafe/open/jrpc/pkg/server" + "github.com/stretchr/testify/require" + "net/http/httptest" + "testing" +) + +func TestBasicSuite(t *testing.T) { + jrpctest.RunBasicTestSuite(t, jrpctest.BasicTestSuiteArgs{ + ServerMaker: func() (*server.Server, jrpctest.ClientMaker, func()) { + s := jrpctest.NewServer() + hsrv := httptest.NewServer(&Server{Server: s}) + return s, func() codec.Conn { + conn, err := DialWebsocket(context.Background(), hsrv.URL, "") + require.NoError(t, err) + return conn + }, hsrv.Close + }, + }) +} diff --git a/contrib/codecs/websocket/dial.go b/contrib/codecs/websocket/dial.go index 3f50e5d6269bdd78deef5409c6ecce197542b0b7..5336370ea648d880da9a257b554e9d28c16cba7e 100644 --- a/contrib/codecs/websocket/dial.go +++ b/contrib/codecs/websocket/dial.go @@ -2,6 +2,9 @@ package websocket import ( "context" + "encoding/base64" + "net/http" + "net/url" "nhooyr.io/websocket" ) @@ -31,15 +34,26 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, opts return nil, err } opts.HTTPHeader = header - return newClient(ctx, func(cctx context.Context) (*websocket.Conn, error) { - conn, resp, err := websocket.Dial(cctx, endpoint, opts) - if err != nil { - hErr := WsHandshakeError{err: err} - if resp != nil { - hErr.status = resp.Status - } - return nil, hErr - } - return conn, err - }) + conn, _, err := websocket.Dial(ctx, endpoint, opts) + if err != nil { + return nil, err + } + return newClient(conn) +} + +func WsClientHeaders(endpoint, origin string) (string, http.Header, error) { + endpointURL, err := url.Parse(endpoint) + if err != nil { + return endpoint, nil, err + } + header := make(http.Header) + if origin != "" { + header.Add("origin", origin) + } + if endpointURL.User != nil { + b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) + header.Add("authorization", "Basic "+b64auth) + endpointURL.User = nil + } + return endpointURL.String(), header, nil } diff --git a/contrib/codecs/websocket/handler.go b/contrib/codecs/websocket/handler.go new file mode 100644 index 0000000000000000000000000000000000000000..db03a3b52fb5e06b230814ead127b388144291d2 --- /dev/null +++ b/contrib/codecs/websocket/handler.go @@ -0,0 +1,46 @@ +package websocket + +import ( + "net/http" + + "nhooyr.io/websocket" + + "gfx.cafe/open/jrpc/pkg/server" +) + +type Server struct { + Server *server.Server +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.Server == nil { + http.Error(w, "no server set", http.StatusInternalServerError) + return + } + conn, err := websocket.Accept(w, r, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + c := newWebsocketCodec(r.Context(), conn, "", r.Header) + s.Server.ServeCodec(r.Context(), c) +} + +// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. +// +// allowedOrigins should be a comma-separated list of allowed origin URLs. +// To allow connections with any origin, pass "*". +func WebsocketHandler(s *server.Server, allowedOrigins []string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: allowedOrigins, + CompressionMode: websocket.CompressionContextTakeover, + CompressionThreshold: 4096, + }) + if err != nil { + return + } + codec := newWebsocketCodec(r.Context(), conn, r.Host, r.Header) + s.ServeCodec(r.Context(), codec) + }) +} diff --git a/contrib/codecs/websocket/websocket.go b/contrib/codecs/websocket/websocket.go deleted file mode 100644 index ea2d3ea936f326220408c829c04061803b3f9239..0000000000000000000000000000000000000000 --- a/contrib/codecs/websocket/websocket.go +++ /dev/null @@ -1,164 +0,0 @@ -package websocket - -import ( - "context" - "encoding/base64" - "encoding/json" - "gfx.cafe/open/jrpc/contrib/codecs/websocket/wsjson" - codec2 "gfx.cafe/open/jrpc/pkg/codec" - "gfx.cafe/open/jrpc/pkg/server" - "net/http" - "net/url" - "time" - - "nhooyr.io/websocket" -) - -// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. -// -// allowedOrigins should be a comma-separated list of allowed origin URLs. -// To allow connections with any origin, pass "*". -func WebsocketHandler(s *server.Server, allowedOrigins []string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - OriginPatterns: allowedOrigins, - CompressionMode: websocket.CompressionContextTakeover, - CompressionThreshold: 4096, - }) - if err != nil { - return - } - codec := newWebsocketCodec(r.Context(), conn, r.Host, r.Header) - s.ServeCodec(codec) - }) -} - -func NewHandshakeError(err error, status string) error { - return &wsHandshakeError{err, status} -} - -type wsHandshakeError struct { - err error - status string -} - -func (e wsHandshakeError) Error() string { - s := e.err.Error() - if e.status != "" { - s += " (HTTP status " + e.status + ")" - } - return s -} - -func WsClientHeaders(endpoint, origin string) (string, http.Header, error) { - endpointURL, err := url.Parse(endpoint) - if err != nil { - return endpoint, nil, err - } - header := make(http.Header) - if origin != "" { - header.Add("origin", origin) - } - if endpointURL.User != nil { - b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) - header.Add("authorization", "Basic "+b64auth) - endpointURL.User = nil - } - return endpointURL.String(), header, nil -} - -type websocketCodec struct { - conn *websocket.Conn - info codec2.PeerInfo - - pingReset chan struct{} - - closed chan any -} - -// if there is more than one message, it is a batch request -func (w *websocketCodec) ReadBatch(ctx context.Context) (msgs json.RawMessage, err error) { - w.conn.SetReadLimit(WsMessageSizeLimit) - err = wsjson.Read(ctx, w.conn, &msgs) - if err != nil { - return nil, err - } - return msgs, nil -} - -// Closed returns a channel which is closed when the connection is closed. -func (w *websocketCodec) Closed() <-chan any { - return w.closed -} - -// RemoteAddr returns the peer address of the connection. -func (w *websocketCodec) RemoteAddr() string { - return w.info.RemoteAddr -} - -func heartbeat(ctx context.Context, c *websocket.Conn, d time.Duration) { - t := time.NewTimer(d) - defer t.Stop() - for { - select { - case <-ctx.Done(): - return - case <-t.C: - } - err := c.Ping(ctx) - if err != nil { - return - } - t.Reset(time.Minute) - } -} - -func newWebsocketCodec(ctx context.Context, c *websocket.Conn, host string, req http.Header) codec2.ReaderWriter { - wc := &websocketCodec{ - conn: c, - pingReset: make(chan struct{}, 1), - info: codec2.PeerInfo{ - Transport: "ws", - }, - closed: make(chan any), - } - // Fill in connection details. - wc.info.HTTP.Host = host - // traefik proxy protocol headers - wc.info.HTTP.Origin = req.Get("X-Real-Ip") - if wc.info.HTTP.Origin == "" { - wc.info.HTTP.Origin = req.Get("X-Forwarded-For") - } - // origin header fallback - if wc.info.HTTP.Origin == "" { - wc.info.HTTP.Origin = req.Get("origin") - } - wc.info.RemoteAddr = wc.info.HTTP.Origin - wc.info.HTTP.UserAgent = req.Get("User-Agent") - wc.info.HTTP.Headers = req - // Start pinger. - go heartbeat(ctx, c, WsPingInterval) - return wc -} - -func (wc *websocketCodec) Close() error { - wc.conn.CloseRead(context.Background()) - close(wc.closed) - return nil -} - -func (wc *websocketCodec) PeerInfo() codec2.PeerInfo { - return wc.info -} - -func (wc *websocketCodec) WriteJSON(ctx context.Context, v any) error { - err := wsjson.Write(ctx, wc.conn, v) - if err == nil { - // Notify pingLoop to delay the next idle ping. - select { - case wc.pingReset <- struct{}{}: - default: - } - } - return err -} diff --git a/contrib/codecs/websocket/websocket_test.go b/contrib/codecs/websocket/websocket_test.go index 0009913986d463e1214ab71c0db501b5b21a93ca..360ce028f0a0cecdac66ba64bb207a3581ad9ebc 100644 --- a/contrib/codecs/websocket/websocket_test.go +++ b/contrib/codecs/websocket/websocket_test.go @@ -2,11 +2,10 @@ package websocket_test import ( "context" - "errors" - websocket2 "gfx.cafe/open/jrpc/contrib/codecs/websocket" + "gfx.cafe/open/jrpc/contrib/codecs/websocket" "gfx.cafe/open/jrpc/contrib/jmux" "gfx.cafe/open/jrpc/pkg/codec" - jrpctest2 "gfx.cafe/open/jrpc/pkg/jrpctest" + "gfx.cafe/open/jrpc/pkg/jrpctest" "gfx.cafe/open/jrpc/pkg/server" "net/http/httptest" "strings" @@ -16,7 +15,7 @@ import ( func TestWebsocketClientHeaders(t *testing.T) { t.Parallel() - endpoint, header, err := websocket2.WsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") + endpoint, header, err := websocket.WsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") if err != nil { t.Fatalf("wsGetConfig failed: %s", err) } @@ -36,25 +35,25 @@ func TestWebsocketOriginCheck(t *testing.T) { t.Parallel() var ( - srv = jrpctest2.NewTestServer() - httpsrv = httptest.NewServer(websocket2.WebsocketHandler(srv, []string{"http://example.com"})) + srv = jrpctest.NewServer() + httpsrv = httptest.NewServer(websocket.WebsocketHandler(srv, []string{"http://example.com"})) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() defer httpsrv.Close() - client, err := websocket2.DialWebsocket(context.Background(), wsURL, "http://ekzample.com") + client, err := websocket.DialWebsocket(context.Background(), wsURL, "http://ekzample.com") if err == nil { client.Close() t.Fatal("no error for wrong origin") } - wantErr := websocket2.NewHandshakeError(errors.New("403"), "403 Forbidden") - if !strings.Contains(err.Error(), wantErr.Error()) { + wantErr := "expected handshake response status code 101 but got 403" + if !strings.Contains(err.Error(), wantErr) { t.Fatalf("wrong error for wrong origin: got: '%q', want: '%s'", err, wantErr) } // Connections without origin header should work. - client, err = websocket2.DialWebsocket(context.Background(), wsURL, "") + client, err = websocket.DialWebsocket(context.Background(), wsURL, "") if err != nil { t.Fatalf("error for empty origin: %v", err) } @@ -66,22 +65,22 @@ func TestWebsocketLargeCall(t *testing.T) { t.Parallel() var ( - srv = jrpctest2.NewTestServer() - httpsrv = httptest.NewServer(websocket2.WebsocketHandler(srv, []string{"*"})) + srv = jrpctest.NewServer() + httpsrv = httptest.NewServer(websocket.WebsocketHandler(srv, []string{"*"})) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() defer httpsrv.Close() - client, err := websocket2.DialWebsocket(context.Background(), wsURL, "") + client, err := websocket.DialWebsocket(context.Background(), wsURL, "") if err != nil { t.Fatalf("can't dial: %v", err) } defer client.Close() // This call sends slightly less than the limit and should work. - var result jrpctest2.EchoResult - arg := strings.Repeat("x", websocket2.MaxRequestContentLength-200) + var result jrpctest.EchoResult + arg := strings.Repeat("x", websocket.MaxRequestContentLength-200) if err := client.Do(nil, &result, "test_echo", []any{arg, 1}); err != nil { t.Fatalf("valid call didn't work: %v", err) } @@ -90,7 +89,7 @@ func TestWebsocketLargeCall(t *testing.T) { } // This call sends twice the allowed size and shouldn't work. - arg = strings.Repeat("x", websocket2.MaxRequestContentLength*2) + arg = strings.Repeat("x", websocket.MaxRequestContentLength*2) err = client.Do(nil, &result, "test_echo", []any{arg}) if err == nil { t.Fatal("no error for too large call") @@ -99,15 +98,15 @@ func TestWebsocketLargeCall(t *testing.T) { func TestWebsocketPeerInfo(t *testing.T) { var ( - s = jrpctest2.NewTestServer() - ts = httptest.NewServer(websocket2.WebsocketHandler(s, []string{"origin.example.com"})) + s = jrpctest.NewServer() + ts = httptest.NewServer(websocket.WebsocketHandler(s, []string{"origin.example.com"})) tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") ) defer s.Stop() defer ts.Close() ctx := context.Background() - c, err := websocket2.DialWebsocket(ctx, tsurl, "http://origin.example.com") + c, err := websocket.DialWebsocket(ctx, tsurl, "http://origin.example.com") if err != nil { t.Fatal(err) } @@ -137,16 +136,18 @@ func TestClientWebsocketLargeMessage(t *testing.T) { mux := jmux.NewMux() var ( srv = server.NewServer(mux) - httpsrv = httptest.NewServer(websocket2.WebsocketHandler(srv, nil)) + httpsrv = httptest.NewServer(websocket.WebsocketHandler(srv, nil)) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") ) defer srv.Stop() defer httpsrv.Close() - respLength := websocket2.WsMessageSizeLimit - 50 - mux.RegisterStruct("test", jrpctest2.LargeRespService{Length: respLength}) + respLength := websocket.WsMessageSizeLimit - 50 + if err := mux.RegisterStruct("test", jrpctest.LargeRespService{Length: respLength}); err != nil { + t.Fatal(err) + } - c, err := websocket2.DialWebsocket(context.Background(), wsURL, "") + c, err := websocket.DialWebsocket(context.Background(), wsURL, "") if err != nil { t.Fatal(err) } diff --git a/readme.md b/readme.md index bb48c318fe6e71ea6d91a3ea01c251b06c3bb7b9..fb762f869f5073d3af9bb9e0f3edca27daccfc8a 100644 --- a/readme.md +++ b/readme.md @@ -26,13 +26,20 @@ contrib/ - packages that add to jrpc codecs/ - client and server transport implementations codecs.go - dialers for all finished codecs http/ - http based codec - codec_test.go - general tests that all must pass - client.go - codec.Conn implementation - codec.go - codec.ReaderWriter implementaiton - const.go - constants - handler.go - http handler - http_test.go - http specific tests - websocket/ - WIP: websocket basec codec + codec_test.go - general tests that all must pass + client.go - codec.Conn implementation + codec.go - codec.ReaderWriter implementaiton + const.go - constants + handler.go - http handler + http_test.go - http specific tests + websocket/ - websocket basec codec + codec_test.go - general tests that all must pass + client.go - codec.Conn implementation + codec.go - codec.ReadWriter implementation + const.go - constants + dial.go - websocket dialer + handler.go - http handler + websocket_test.go - websocket specific tests inproc/ - WIP: inproc based codec ipc/ - WIP: ipc based codec stdio/ - WIP: stdio based codec (variation of ipc)