diff --git a/jrpc.go b/jrpc.go index 404c88bf3e9739cf6b59945d0e27825fc760e704..766ab261c16feb1ae15a08045c450bb939fb0d01 100644 --- a/jrpc.go +++ b/jrpc.go @@ -53,3 +53,24 @@ type BatchElem struct { // unmarshaling into Result fails. It is not set for I/O errors. Error error } + +type clientContextKey struct{} + +// ClientFromContext retrieves the client from the context, if any. This can be used to perform +// 'reverse calls' in a handler method. +func ContextWithConn(ctx context.Context, c Conn) context.Context { + client, _ := ctx.Value(clientContextKey{}).(Conn) + return context.WithValue(ctx, clientContextKey{}, client) +} + +// ClientFromContext retrieves the client from the context, if any. This can be used to perform +// 'reverse calls' in a handler method. +func ConnFromContext(ctx context.Context) (Conn, bool) { + client, ok := ctx.Value(clientContextKey{}).(Conn) + return client, ok +} + +func StreamingConnFromContext(ctx context.Context) (StreamingConn, bool) { + client, ok := ctx.Value(clientContextKey{}).(StreamingConn) + return client, ok +} diff --git a/pkg/codec/codecs/http/client.go b/pkg/codec/codecs/http/client.go index 4cc3c8f78e2e9a475c6256b82ed837b8f8b04dfb..978857f26969eea64c53c8c7091ede878af74bcd 100644 --- a/pkg/codec/codecs/http/client.go +++ b/pkg/codec/codecs/http/client.go @@ -5,12 +5,15 @@ import ( "context" "encoding/json" "errors" + "fmt" + "io" "net/http" "sync/atomic" "time" "gfx.cafe/open/jrpc/pkg/clientutil" "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/util/go/bufpool" "gfx.cafe/open/jrpc" ) @@ -37,35 +40,49 @@ type Client struct { c *http.Client id atomic.Int64 + + headers http.Header +} + +func DialHTTP(target string) (*Client, error) { + return Dial(nil, http.DefaultClient, target) } func Dial(ctx context.Context, client *http.Client, target string) (*Client, error) { - return &Client{remote: target, c: client}, nil + return &Client{remote: target, c: client, headers: http.Header{}}, nil +} + +func (c *Client) SetHeader(key string, value string) { + c.headers.Set(key, value) } func (c *Client) Do(ctx context.Context, result any, method string, params any) error { req := jrpc.NewRequestInt(ctx, int(c.id.Add(1)), method, params) - dat, err := req.MarshalJSON() - if err != nil { - return err - } - resp, err := c.c.Post(c.remote, "application/json", bytes.NewBuffer(dat)) + resp, err := c.post(req) if err != nil { return err } defer resp.Body.Close() + if resp.StatusCode != 200 { + b, _ := io.ReadAll(resp.Body) + return &codec.HTTPError{ + StatusCode: resp.StatusCode, + Status: resp.Status, + Body: b, + } + } // TODO: this can be reused msg := clientutil.GetMessage() defer clientutil.PutMessage(msg) err = json.NewDecoder(resp.Body).Decode(&msg) if err != nil { - return err + return fmt.Errorf("decode json: %w", err) } if msg.Error != nil { return err } - if result != nil { - err = json.Unmarshal(msg.Result, &msg) + if result != nil && len(msg.Result) > 0 { + err = json.Unmarshal(msg.Result, &result) if err != nil { return err } @@ -73,16 +90,34 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any) return nil } -func (c *Client) Notify(ctx context.Context, method string, params any) error { - req := jrpc.NewRequestInt(ctx, int(c.id.Add(1)), method, params) - dat, err := req.MarshalJSON() +func (c *Client) post(req *jrpc.Request) (*http.Response, error) { + //TODO: use buffer for this + buf := bufpool.GetStd() + defer bufpool.PutStd(buf) + buf.Reset() + err := json.NewEncoder(buf).Encode(req) if err != nil { - return err + return nil, err + } + hreq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, c.remote, buf) + if err != nil { + return nil, err } - _, err = c.c.Post(c.remote, "application/json", bytes.NewBuffer(dat)) + for k, v := range c.headers { + for _, vv := range v { + hreq.Header.Add(k, vv) + } + } + return c.c.Do(hreq) +} + +func (c *Client) Notify(ctx context.Context, method string, params any) error { + req := jrpc.NewNotification(ctx, method, params) + resp, err := c.post(req) if err != nil { return err } + resp.Body.Close() return err } diff --git a/pkg/codec/codecs/http/client_test.go b/pkg/codec/codecs/http/client_test.go index ac8c0904f9b3df90a4b5b91835fb46e785b1f7b3..b99934a02a72acde2954158959032bbbc347d386 100644 --- a/pkg/codec/codecs/http/client_test.go +++ b/pkg/codec/codecs/http/client_test.go @@ -1,459 +1,24 @@ package http import ( - "context" - "fmt" - "math/rand" - "net" - "net/http" "net/http/httptest" - "os" - "reflect" - "runtime" - "sync" "testing" - "time" - "github.com/davecgh/go-spew/spew" - "tuxpa.in/a/zlog" - "tuxpa.in/a/zlog/log" + "gfx.cafe/open/jrpc" + "gfx.cafe/open/jrpc/pkg/jrpctest" + "github.com/stretchr/testify/require" ) -func init() { - zlog.SetGlobalLevel(zlog.FatalLevel) -} - -func TestClientRequest(t *testing.T) { - server := newTestServer() - defer server.Stop() - client := DialInProc(server) - defer client.Close() - - var resp echoResult - if err := client.Call(nil, &resp, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(resp, echoResult{"hello", 10, &echoArgs{"world"}}) { - t.Errorf("incorrect result %#v", resp) - } -} - -func TestClientResponseType(t *testing.T) { - server := newTestServer() - defer server.Stop() - client := DialInProc(server) - defer client.Close() - - if err := client.Call(nil, nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { - t.Errorf("Passing nil as result should be fine, but got an error: %v", err) - } - var resultVar echoResult - // Note: passing the var, not a ref - err := client.Call(nil, resultVar, "test_echo", "hello", 10, &echoArgs{"world"}) - if err == nil { - t.Error("Passing a var as result should be an error") - } -} - -// This test checks that server-returned errors with code and data come out of Client.Call. -func TestClientErrorData(t *testing.T) { - server := newTestServer() - defer server.Stop() - client := DialInProc(server) - defer client.Close() - - var resp any - err := client.Call(nil, &resp, "test_returnError") - if err == nil { - t.Fatal("expected error") - } - - // Check code. - if e, ok := err.(Error); !ok { - t.Fatalf("client did not return rpc.Error, got %#v", e) - } else if e.ErrorCode() != (testError{}.ErrorCode()) { - t.Fatalf("wrong error code %d, want %d", e.ErrorCode(), testError{}.ErrorCode()) - } - // Check data. - if e, ok := err.(DataError); !ok { - t.Fatalf("client did not return rpc.DataError, got %#v", e) - } else if e.ErrorData() != (testError{}.ErrorData()) { - t.Fatalf("wrong error data %#v, want %#v", e.ErrorData(), testError{}.ErrorData()) - } -} - -func TestClientBatchRequest(t *testing.T) { - server := newTestServer() - defer server.Stop() - client := DialInProc(server) - defer client.Close() - batch := []BatchElem{ - { - Method: "test_echo", - Args: []any{"hello", 10, &echoArgs{"world"}}, - Result: new(echoResult), - }, - { - Method: "test_echo", - Args: []any{"hello2", 11, &echoArgs{"world"}}, - Result: new(echoResult), - }, - { - Method: "no_such_method", - Args: []any{1, 2, 3}, - Result: new(int), - }, - } - if err := client.BatchCall(nil, batch...); err != nil { - t.Fatal(err) - } - wantResult := []BatchElem{ - { - Method: "test_echo", - Args: []any{"hello", 10, &echoArgs{"world"}}, - Result: &echoResult{"hello", 10, &echoArgs{"world"}}, +func TestBasicSuite(t *testing.T) { + jrpctest.RunBasicTestSuite(t, jrpctest.BasicTestSuiteArgs{ + ServerMaker: func() (*jrpc.Server, jrpctest.ClientMaker, func()) { + s := jrpctest.NewServer() + hsrv := httptest.NewServer(&Server{Server: s}) + return s, func() jrpc.Conn { + conn, err := DialHTTP(hsrv.URL) + require.NoError(t, err) + return conn + }, hsrv.Close }, - { - Method: "test_echo", - Args: []any{"hello2", 11, &echoArgs{"world"}}, - Result: &echoResult{"hello2", 11, &echoArgs{"world"}}, - }, - { - Method: "no_such_method", - Args: []any{1, 2, 3}, - Result: new(int), - Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"}, - }, - } - if !reflect.DeepEqual(batch, wantResult) { - t.Errorf("batch results mismatch:\ngot %swant %s", spew.Sdump(batch), spew.Sdump(wantResult)) - } -} - -func TestClientNotify(t *testing.T) { - server := newTestServer() - defer server.Stop() - client := DialInProc(server) - defer client.Close() - - if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { - t.Fatal(err) - } -} - -// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } -func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } -func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } -func TestClientCancelIPC(t *testing.T) { testClientCancel("ipc", t) } - -// This test checks that requests made through Call can be canceled by canceling -// the context. -func testClientCancel(transport string, t *testing.T) { - // These tests take a lot of time, run them all at once. - // You probably want to run with -parallel 1 or comment out - // the call to t.Parallel if you enable the logging. - t.Parallel() - - server := newTestServer() - defer server.Stop() - - // What we want to achieve is that the context gets canceled - // at various stages of request processing. The interesting cases - // are: - // - cancel during dial - // - cancel while performing a HTTP request - // - cancel while waiting for a response - // - // To trigger those, the times are chosen such that connections - // are killed within the deadline for every other call (maxKillTimeout - // is 2x maxCancelTimeout). - // - // Once a connection is dead, there is a fair chance it won't connect - // successfully because the accept is delayed by 1s. - maxContextCancelTimeout := 300 * time.Millisecond - fl := &flakeyListener{ - maxAcceptDelay: 1 * time.Second, - maxKillTimeout: 600 * time.Millisecond, - } - - var client *Client - switch transport { - case "ws", "http": - c, hs := httpTestClient(server, transport, fl) - defer hs.Close() - client = c - case "ipc": - c, l := ipcTestClient(server, fl) - defer l.Close() - client = c - default: - panic("unknown transport: " + transport) - } - - // The actual test starts here. - var ( - wg sync.WaitGroup - nreqs = 10 - ncallers = 10 - ) - caller := func(index int) { - defer wg.Done() - for i := 0; i < nreqs; i++ { - var ( - ctx context.Context - cancel func() - timeout = time.Duration(rand.Int63n(int64(maxContextCancelTimeout))) - ) - if index < ncallers/2 { - // For half of the callers, create a context without deadline - // and cancel it later. - ctx, cancel = context.WithCancel(context.Background()) - time.AfterFunc(timeout, cancel) - } else { - // For the other half, create a context with a deadline instead. This is - // different because the context deadline is used to set the socket write - // deadline. - ctx, cancel = context.WithTimeout(context.Background(), timeout) - } - - // Now perform a call with the context. - // The key thing here is that no call will ever complete successfully. - err := client.Call(ctx, nil, "test_block") - switch { - case err == nil: - _, hasDeadline := ctx.Deadline() - t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) - // default: - // t.Logf("got expected error with %v wait time: %v", timeout, err) - } - cancel() - } - } - wg.Add(ncallers) - for i := 0; i < ncallers; i++ { - go caller(i) - } - wg.Wait() -} - -func TestClientSetHeader(t *testing.T) { - var gotHeader bool - srv := newTestServer() - httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("test") == "ok" { - gotHeader = true - } - srv.ServeHTTP(w, r) - })) - defer httpsrv.Close() - defer srv.Stop() - - client, err := Dial(httpsrv.URL) - if err != nil { - t.Fatal(err) - } - defer client.Close() - - client.SetHeader("test", "ok") - if _, err := client.SupportedModules(); err != nil { - t.Fatal(err) - } - if !gotHeader { - t.Fatal("client did not set custom header") - } - - //NOTE: this test is removed because we accept invalid content types - // Check that Content-Type can be replaced. - //client.SetHeader("content-type", "application/x-garbage") - //_, err = client.SupportedModules() - //if err == nil { - // t.Fatal("no error for invalid content-type header") - //} else if !strings.Contains(err.Error(), "Unsupported Media Type") { - // t.Fatalf("error is not related to content-type: %q", err) - //} -} - -func TestClientHTTP(t *testing.T) { - server := newTestServer() - defer server.Stop() - - client, hs := httpTestClient(server, "http", nil) - defer hs.Close() - defer client.Close() - - // Launch concurrent requests. - var ( - results = make([]echoResult, 100) - errc = make(chan error, len(results)) - wantResult = echoResult{"a", 1, new(echoArgs)} - ) - defer client.Close() - for i := range results { - i := i - go func() { - errc <- client.Call(nil, &results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args) - }() - } - - // Wait for all of them to complete. - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() - for i := range results { - select { - case err := <-errc: - if err != nil { - t.Fatal(err) - } - case <-timeout.C: - t.Fatalf("timeout (got %d/%d) results)", i+1, len(results)) - } - } - - // Check results. - for i := range results { - if !reflect.DeepEqual(results[i], wantResult) { - t.Errorf("result %d mismatch: got %#v, want %#v", i, results[i], wantResult) - } - } -} - -func TestClientReconnect(t *testing.T) { - startServer := func(addr string) (*Server, net.Listener) { - srv := newTestServer() - l, err := net.Listen("tcp", addr) - if err != nil { - t.Fatal("can't listen:", err) - } - go http.Serve(l, srv.WebsocketHandler([]string{"*"})) - return srv, l - } - - ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) - defer cancel() - - // Start a server and corresponding client. - s1, l1 := startServer("127.0.0.1:0") - client, err := DialContext(ctx, "ws://"+l1.Addr().String()) - if err != nil { - t.Fatal("can't dial", err) - } - defer client.Close() - - // Perform a call. This should work because the server is up. - var resp echoResult - if err := client.Call(ctx, &resp, "test_echo", "", 1, nil); err != nil { - t.Fatal(err) - } - - // Shut down the server and allow for some cool down time so we can listen on the same - // address again. - l1.Close() - s1.Stop() - time.Sleep(2 * time.Second) - - // Try calling again. It shouldn't work. - if err := client.Call(ctx, &resp, "test_echo", "", 2, nil); err == nil { - t.Error("successful call while the server is down") - t.Logf("resp: %#v", resp) - } - - // Start it up again and call again. The connection should be reestablished. - // We spawn multiple calls here to check whether this hangs somehow. - s2, l2 := startServer(l1.Addr().String()) - defer l2.Close() - defer s2.Stop() - - start := make(chan struct{}) - errors := make(chan error, 20) - for i := 0; i < cap(errors); i++ { - go func() { - <-start - var resp echoResult - errors <- client.Call(ctx, &resp, "test_echo", "", 3, nil) - }() - } - close(start) - errcount := 0 - for i := 0; i < cap(errors); i++ { - if err = <-errors; err != nil { - errcount++ - } - } - t.Logf("%d errors, last error: %v", errcount, err) - if errcount > 1 { - t.Errorf("expected one error after disconnect, got %d", errcount) - } -} - -func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) { - // Create the HTTP server. - var hs *httptest.Server - switch transport { - case "ws": - hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) - case "http": - hs = httptest.NewUnstartedServer(srv) - default: - panic("unknown HTTP transport: " + transport) - } - // Wrap the listener if required. - if fl != nil { - fl.Listener = hs.Listener - hs.Listener = fl - } - // Connect the client. - hs.Start() - client, err := Dial(transport + "://" + hs.Listener.Addr().String()) - if err != nil { - panic(err) - } - return client, hs -} - -func ipcTestClient(srv *Server, fl *flakeyListener) (*Client, net.Listener) { - // Listen on a random endpoint. - endpoint := fmt.Sprintf("go-ethereum-test-ipc-%d-%d", os.Getpid(), rand.Int63()) - if runtime.GOOS == "windows" { - endpoint = `\\.\pipe\` + endpoint - } else { - endpoint = os.TempDir() + "/" + endpoint - } - l, err := ipcListen(endpoint) - if err != nil { - panic(err) - } - // Connect the listener to the server. - if fl != nil { - fl.Listener = l - l = fl - } - go srv.ServeListener(l) - // Connect the client. - client, err := Dial(endpoint) - if err != nil { - panic(err) - } - return client, l -} - -// flakeyListener kills accepted connections after a random timeout. -type flakeyListener struct { - net.Listener - maxKillTimeout time.Duration - maxAcceptDelay time.Duration -} - -func (l *flakeyListener) Accept() (net.Conn, error) { - delay := time.Duration(rand.Int63n(int64(l.maxAcceptDelay))) - time.Sleep(delay) - - c, err := l.Listener.Accept() - if err == nil { - timeout := time.Duration(rand.Int63n(int64(l.maxKillTimeout))) - time.AfterFunc(timeout, func() { - log.Debug().Msg(fmt.Sprintf("killing conn %v after %v", c.LocalAddr(), timeout)) - c.Close() - }) - } - return c, err + }) } diff --git a/pkg/codec/codecs/http/codec.go b/pkg/codec/codecs/http/codec.go index 62fba6ccf00d5dc7cdd29a23bbe2a473b33f6bd2..4ebad4138d587c3d27790af8e55b7fca6d9e56f7 100644 --- a/pkg/codec/codecs/http/codec.go +++ b/pkg/codec/codecs/http/codec.go @@ -1,11 +1,14 @@ package http import ( + "bufio" "context" "encoding/base64" "encoding/json" "errors" + "fmt" "io" + "mime" "net/http" "net/url" @@ -17,49 +20,58 @@ type Codec struct { ctx context.Context cn func() - r *http.Request - w http.ResponseWriter - msgs chan json.RawMessage - errs chan error + r *http.Request + w http.ResponseWriter + wr *bufio.Writer + msgs chan json.RawMessage + errCh chan httpError + + i codec.PeerInfo +} + +type httpError struct { + code int + err error } -func NewCodec(r *http.Request, w http.ResponseWriter) *Codec { - ctx, cn := context.WithCancel(r.Context()) +func NewCodec(w http.ResponseWriter, r *http.Request) *Codec { c := &Codec{ - ctx: ctx, - cn: cn, - r: r, - w: w, - msgs: make(chan json.RawMessage, 1), - errs: make(chan error, 1), + r: r, + w: w, + wr: bufio.NewWriter(w), + msgs: make(chan json.RawMessage, 1), + errCh: make(chan httpError, 1), } - go c.doRead() + ctx := r.Context() + c.ctx, c.cn = context.WithCancel(ctx) + c.peerInfo() + c.doRead() return c } - -// gets the peer info -func (c *Codec) PeerInfo() codec.PeerInfo { - ci := codec.PeerInfo{ - Transport: "http", - RemoteAddr: c.r.RemoteAddr, - HTTP: codec.HttpInfo{ - Version: c.r.Proto, - UserAgent: c.r.UserAgent(), - Host: c.r.Host, - Headers: c.r.Header.Clone(), - }, +func (c *Codec) peerInfo() { + c.i.Transport = "http" + c.i.RemoteAddr = c.r.RemoteAddr + c.i.HTTP = codec.HttpInfo{ + Version: c.r.Proto, + UserAgent: c.r.UserAgent(), + Host: c.r.Host, + Headers: c.r.Header.Clone(), } - ci.HTTP.Origin = c.r.Header.Get("X-Real-Ip") - if ci.HTTP.Origin == "" { - ci.HTTP.Origin = c.r.Header.Get("X-Forwarded-For") + c.i.HTTP.Origin = c.r.Header.Get("X-Real-Ip") + if c.i.HTTP.Origin == "" { + c.i.HTTP.Origin = c.r.Header.Get("X-Forwarded-For") } - if ci.HTTP.Origin == "" { - ci.HTTP.Origin = c.r.Header.Get("Origin") + if c.i.HTTP.Origin == "" { + c.i.HTTP.Origin = c.r.Header.Get("Origin") } - if ci.HTTP.Origin == "" { - ci.HTTP.Origin = c.r.RemoteAddr + if c.i.HTTP.Origin == "" { + c.i.HTTP.Origin = c.r.RemoteAddr } - return ci +} + +// gets the peer info +func (c *Codec) PeerInfo() codec.PeerInfo { + return c.i } func (r *Codec) doReadGet() (msgs json.RawMessage, err error) { @@ -77,34 +89,61 @@ func (r *Codec) doReadGet() (msgs json.RawMessage, err error) { return req.MarshalJSON() } -var ErrInvalidContentType = errors.New("invalid content type") - -func (c *Codec) doRead() { - contentMatches := true - types := c.r.Header.Values("content-type") - for _, v := range types { - // TODO: check content type - _ = v +// validateRequest returns a non-zero response code and error message if the +// request is invalid. +func ValidateRequest(r *http.Request) (int, error) { + if r.Method == http.MethodPut || r.Method == http.MethodDelete { + return http.StatusMethodNotAllowed, errors.New("method not allowed") } - if !contentMatches { - c.errs <- ErrInvalidContentType - return + if r.ContentLength > maxRequestContentLength { + err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength) + return http.StatusRequestEntityTooLarge, err } - var data json.RawMessage - var err error - // TODO: implement eventsource - switch c.r.Method { - case http.MethodGet: - data, err = c.doReadGet() - return - case http.MethodPost: - data, err = io.ReadAll(c.r.Body) + // Allow OPTIONS (regardless of content-type) + if r.Method == http.MethodOptions { + return 0, nil + } + // Check content-type + if mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")); err == nil { + for _, accepted := range acceptedContentTypes { + if accepted == mt { + return 0, nil + } + } } + // Invalid content-type ignored for now + return 0, nil + //err := fmt.Errorf("invalid content type, only %s is supported", contentType) + //return http.StatusUnsupportedMediaType, err +} + +func (c *Codec) doRead() { + code, err := ValidateRequest(c.r) if err != nil { - c.errs <- err + c.errCh <- httpError{ + code: code, + err: err, + } return } - c.msgs <- data + go func() { + var data json.RawMessage + // TODO: implement eventsource + switch c.r.Method { + case http.MethodGet: + data, err = c.doReadGet() + case http.MethodPost: + data, err = io.ReadAll(c.r.Body) + } + if err != nil { + c.errCh <- httpError{ + code: http.StatusInternalServerError, + err: err, + } + return + } + c.msgs <- data + }() } // json.RawMessage can be an array of requests. if it is, then it is a batch request @@ -112,8 +151,9 @@ func (c *Codec) ReadBatch(ctx context.Context) (msgs json.RawMessage, err error) select { case ans := <-c.msgs: return ans, nil - case err := <-c.errs: - return nil, err + case err := <-c.errCh: + http.Error(c.w, err.err.Error(), err.code) + return nil, err.err case <-ctx.Done(): return nil, ctx.Err() case <-c.ctx.Done(): @@ -128,7 +168,16 @@ func (c *Codec) Close() error { } func (c *Codec) Write(p []byte) (n int, err error) { - return c.w.Write(p) + return c.wr.Write(p) +} + +func (c *Codec) Flush() (err error) { + err = c.wr.Flush() + if err != nil { + return err + } + c.cn() + return } // Closed returns a channel which is closed when the connection is closed. @@ -138,5 +187,5 @@ func (c *Codec) Closed() <-chan struct{} { // RemoteAddr returns the peer address of the connection. func (c *Codec) RemoteAddr() string { - return "" + return c.r.RemoteAddr } diff --git a/pkg/codec/codecs/http/const.go b/pkg/codec/codecs/http/const.go new file mode 100644 index 0000000000000000000000000000000000000000..bb84f2cd8b6674f94844952ad842090ffa22429e --- /dev/null +++ b/pkg/codec/codecs/http/const.go @@ -0,0 +1,18 @@ +package http + +import "errors" + +const ( + // NOTE: if you change this, you will have to change the thing in jrpctest... its what its for now until tests get refactored + maxRequestContentLength = 1024 * 1024 * 5 + contentType = "application/json" +) + +// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 +var acceptedContentTypes = []string{ + // https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 + contentType, "application/json-rpc", "application/jsonrequest", + // these are added because they make sense, fight me! + "application/jsonrpc2", "application/json-rpc2", "application/jrpc", +} +var ErrInvalidContentType = errors.New("invalid content type") diff --git a/pkg/codec/codecs/http/handler.go b/pkg/codec/codecs/http/handler.go new file mode 100644 index 0000000000000000000000000000000000000000..c4f74a549b8aa79ea1f1024a031f803a22b439e0 --- /dev/null +++ b/pkg/codec/codecs/http/handler.go @@ -0,0 +1,20 @@ +package http + +import ( + "net/http" + + "gfx.cafe/open/jrpc" +) + +type Server struct { + Server *jrpc.Server +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.Server == nil { + http.Error(w, "no server set", http.StatusInternalServerError) + return + } + c := NewCodec(w, r) + s.Server.ServeCodec(r.Context(), c) +} diff --git a/pkg/codec/codecs/http/http_test.go b/pkg/codec/codecs/http/http_test.go index 41fbb6dcb25e1a2e8419736760302a3db32ec3df..99aba33b51c596a730eb315350917972d0b87061 100644 --- a/pkg/codec/codecs/http/http_test.go +++ b/pkg/codec/codecs/http/http_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 The go-ethereum Authors +// Copyright 2018 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -21,8 +21,15 @@ import ( "net/http/httptest" "strings" "testing" + + "gfx.cafe/open/jrpc" + "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/open/jrpc/pkg/jmux" + "gfx.cafe/open/jrpc/pkg/jrpctest" ) +const respLength = maxRequestContentLength * 3 + func confirmStatusCode(t *testing.T, got, want int) { t.Helper() if got == want { @@ -42,7 +49,7 @@ func confirmRequestValidationCode(t *testing.T, method, contentType, body string if len(contentType) > 0 { request.Header.Set("Content-Type", contentType) } - code, err := validateRequest(request) + code, err := ValidateRequest(request) if code == 0 { if err != nil { t.Errorf("validation: got error %v, expected nil", err) @@ -79,8 +86,9 @@ func TestHTTPErrorResponseWithValidRequest(t *testing.T) { func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { t.Helper() - s := Server{} - ts := httptest.NewServer(&s) + s := jrpc.NewServer(jmux.NewMux()) + defer s.Stop() + ts := httptest.NewServer(&Server{Server: s}) defer ts.Close() request, err := http.NewRequest(method, ts.URL, strings.NewReader(body)) @@ -103,12 +111,9 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) { // This checks that maxRequestContentLength is not applied to the response of a request. func TestHTTPRespBodyUnlimited(t *testing.T) { - const respLength = maxRequestContentLength * 3 - - s := NewServer() + s := jrpctest.NewServer() defer s.Stop() - s.Router().RegisterStruct("test", largeRespService{respLength}) - ts := httptest.NewServer(s) + ts := httptest.NewServer(&Server{Server: s}) defer ts.Close() c, err := DialHTTP(ts.URL) @@ -118,7 +123,7 @@ func TestHTTPRespBodyUnlimited(t *testing.T) { defer c.Close() var r string - if err := c.Call(nil, &r, "test_largeResp"); err != nil { + if err := c.Do(nil, &r, "large_largeResp", nil); err != nil { t.Fatal(err) } if len(r) != respLength { @@ -140,12 +145,12 @@ func TestHTTPErrorResponse(t *testing.T) { } var r string - err = c.Call(nil, &r, "test_method") + err = c.Do(nil, &r, "test_method", nil) if err == nil { t.Fatal("error was expected") } - httpErr, ok := err.(HTTPError) + httpErr, ok := err.(*codec.HTTPError) if !ok { t.Fatalf("unexpected error type %T", err) } @@ -166,12 +171,12 @@ func TestHTTPErrorResponse(t *testing.T) { } func TestHTTPPeerInfo(t *testing.T) { - s := newTestServer() + s := jrpctest.NewServer() defer s.Stop() - ts := httptest.NewServer(s) + ts := httptest.NewServer(&Server{Server: s}) defer ts.Close() - c, err := Dial(ts.URL) + c, err := DialHTTP(ts.URL) if err != nil { t.Fatal(err) } @@ -179,8 +184,8 @@ func TestHTTPPeerInfo(t *testing.T) { c.SetHeader("x-forwarded-for", "origin.example.com") // Request peer information. - var info PeerInfo - if err := c.Call(nil, &info, "test_peerInfo"); err != nil { + var info codec.PeerInfo + if err := c.Do(nil, &info, "test_peerInfo", nil); err != nil { t.Fatal(err) } diff --git a/pkg/codec/codecs/inproc/inproc.go b/pkg/codec/codecs/inproc/inproc.go index 82f0c0830c95f2406ce7025fd78370cca9f09abe..53b60dea772b39e7766da3d65398e38797be87cf 100644 --- a/pkg/codec/codecs/inproc/inproc.go +++ b/pkg/codec/codecs/inproc/inproc.go @@ -4,8 +4,9 @@ import ( "bufio" "context" "encoding/json" - "gfx.cafe/open/jrpc/pkg/codec" "io" + + "gfx.cafe/open/jrpc/pkg/codec" ) type Codec struct { @@ -13,7 +14,7 @@ type Codec struct { cn func() rd io.Reader - wr io.Writer + wr *bufio.Writer msgs chan json.RawMessage } @@ -24,7 +25,7 @@ func NewCodec() *Codec { ctx: ctx, cn: cn, rd: bufio.NewReader(rd), - wr: wr, + wr: bufio.NewWriter(wr), msgs: make(chan json.RawMessage, 8), } } @@ -60,6 +61,10 @@ func (c *Codec) Write(p []byte) (n int, err error) { return c.wr.Write(p) } +func (c *Codec) Flush() (err error) { + return c.wr.Flush() +} + // Closed returns a channel which is closed when the connection is closed. func (c *Codec) Closed() <-chan struct{} { return c.ctx.Done() diff --git a/pkg/codec/peer.go b/pkg/codec/peer.go index 9dba5dfa603200ae4ed20baa5b211d328c99001d..32ab3eeaf3b89ed72e15fd4dd59eaefa29e60033 100644 --- a/pkg/codec/peer.go +++ b/pkg/codec/peer.go @@ -1,6 +1,8 @@ package codec -import "net/http" +import ( + "net/http" +) type PeerInfo struct { // Transport is name of the protocol used by the client. diff --git a/pkg/codec/transport.go b/pkg/codec/transport.go index e683473aa6c9b71fbb7a156ba07de8613c08280d..f99daa32c82576492abd7682754988249c2ccadf 100644 --- a/pkg/codec/transport.go +++ b/pkg/codec/transport.go @@ -27,6 +27,9 @@ type Reader interface { type Writer interface { // write json blob to stream io.Writer + // Flush flushes the writer to the stream between messages + Flush() error + // Closed returns a channel which is closed when the connection is closed. Closed() <-chan struct{} // RemoteAddr returns the peer address of the connection. diff --git a/pkg/jrpctest/convert_test.go b/pkg/jrpctest/convert_test.go new file mode 100644 index 0000000000000000000000000000000000000000..391d542aae8924e0e8e1c2a5c980cfd42ada0d78 --- /dev/null +++ b/pkg/jrpctest/convert_test.go @@ -0,0 +1,448 @@ +package jrpctest_test + +import ( + "context" + "fmt" + "math/rand" + "net" + "net/http" + "net/http/httptest" + "os" + "reflect" + "runtime" + "sync" + "testing" + "time" + + "gfx.cafe/open/jrpc/pkg/jrpctest" + "github.com/anacrolix/log" + "github.com/davecgh/go-spew/spew" +) + +func TestClientRequest(t *testing.T) { + s := jrpctest.NewServer() + defer s.Stop() + ts := httptest.NewServer(&Server{Server: s}) + defer ts.Close() + +} + +func TestClientResponseType(t *testing.T) { + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Call(nil, nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Errorf("Passing nil as result should be fine, but got an error: %v", err) + } + var resultVar echoResult + // Note: passing the var, not a ref + err := client.Call(nil, resultVar, "test_echo", "hello", 10, &echoArgs{"world"}) + if err == nil { + t.Error("Passing a var as result should be an error") + } +} + +// This test checks that server-returned errors with code and data come out of Client.Call. +func TestClientErrorData(t *testing.T) { + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + var resp any + err := client.Call(nil, &resp, "test_returnError") + if err == nil { + t.Fatal("expected error") + } + + // Check code. + if e, ok := err.(Error); !ok { + t.Fatalf("client did not return rpc.Error, got %#v", e) + } else if e.ErrorCode() != (testError{}.ErrorCode()) { + t.Fatalf("wrong error code %d, want %d", e.ErrorCode(), testError{}.ErrorCode()) + } + // Check data. + if e, ok := err.(DataError); !ok { + t.Fatalf("client did not return rpc.DataError, got %#v", e) + } else if e.ErrorData() != (testError{}.ErrorData()) { + t.Fatalf("wrong error data %#v, want %#v", e.ErrorData(), testError{}.ErrorData()) + } +} + +func TestClientBatchRequest(t *testing.T) { + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + batch := []BatchElem{ + { + Method: "test_echo", + Args: []any{"hello", 10, &echoArgs{"world"}}, + Result: new(echoResult), + }, + { + Method: "test_echo", + Args: []any{"hello2", 11, &echoArgs{"world"}}, + Result: new(echoResult), + }, + { + Method: "no_such_method", + Args: []any{1, 2, 3}, + Result: new(int), + }, + } + if err := client.BatchCall(nil, batch...); err != nil { + t.Fatal(err) + } + wantResult := []BatchElem{ + { + Method: "test_echo", + Args: []any{"hello", 10, &echoArgs{"world"}}, + Result: &echoResult{"hello", 10, &echoArgs{"world"}}, + }, + { + Method: "test_echo", + Args: []any{"hello2", 11, &echoArgs{"world"}}, + Result: &echoResult{"hello2", 11, &echoArgs{"world"}}, + }, + { + Method: "no_such_method", + Args: []any{1, 2, 3}, + Result: new(int), + Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"}, + }, + } + if !reflect.DeepEqual(batch, wantResult) { + t.Errorf("batch results mismatch:\ngot %swant %s", spew.Sdump(batch), spew.Sdump(wantResult)) + } +} + +func TestClientNotify(t *testing.T) { + server := newTestServer() + defer server.Stop() + client := DialInProc(server) + defer client.Close() + + if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { + t.Fatal(err) + } +} + +// func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } +func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } +func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } +func TestClientCancelIPC(t *testing.T) { testClientCancel("ipc", t) } + +// This test checks that requests made through Call can be canceled by canceling +// the context. +func testClientCancel(transport string, t *testing.T) { + // These tests take a lot of time, run them all at once. + // You probably want to run with -parallel 1 or comment out + // the call to t.Parallel if you enable the logging. + t.Parallel() + + server := newTestServer() + defer server.Stop() + + // What we want to achieve is that the context gets canceled + // at various stages of request processing. The interesting cases + // are: + // - cancel during dial + // - cancel while performing a HTTP request + // - cancel while waiting for a response + // + // To trigger those, the times are chosen such that connections + // are killed within the deadline for every other call (maxKillTimeout + // is 2x maxCancelTimeout). + // + // Once a connection is dead, there is a fair chance it won't connect + // successfully because the accept is delayed by 1s. + maxContextCancelTimeout := 300 * time.Millisecond + fl := &flakeyListener{ + maxAcceptDelay: 1 * time.Second, + maxKillTimeout: 600 * time.Millisecond, + } + + var client *Client + switch transport { + case "ws", "http": + c, hs := httpTestClient(server, transport, fl) + defer hs.Close() + client = c + case "ipc": + c, l := ipcTestClient(server, fl) + defer l.Close() + client = c + default: + panic("unknown transport: " + transport) + } + + // The actual test starts here. + var ( + wg sync.WaitGroup + nreqs = 10 + ncallers = 10 + ) + caller := func(index int) { + defer wg.Done() + for i := 0; i < nreqs; i++ { + var ( + ctx context.Context + cancel func() + timeout = time.Duration(rand.Int63n(int64(maxContextCancelTimeout))) + ) + if index < ncallers/2 { + // For half of the callers, create a context without deadline + // and cancel it later. + ctx, cancel = context.WithCancel(context.Background()) + time.AfterFunc(timeout, cancel) + } else { + // For the other half, create a context with a deadline instead. This is + // different because the context deadline is used to set the socket write + // deadline. + ctx, cancel = context.WithTimeout(context.Background(), timeout) + } + + // Now perform a call with the context. + // The key thing here is that no call will ever complete successfully. + err := client.Call(ctx, nil, "test_block") + switch { + case err == nil: + _, hasDeadline := ctx.Deadline() + t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) + // default: + // t.Logf("got expected error with %v wait time: %v", timeout, err) + } + cancel() + } + } + wg.Add(ncallers) + for i := 0; i < ncallers; i++ { + go caller(i) + } + wg.Wait() +} + +func TestClientSetHeader(t *testing.T) { + var gotHeader bool + srv := newTestServer() + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("test") == "ok" { + gotHeader = true + } + srv.ServeHTTP(w, r) + })) + defer httpsrv.Close() + defer srv.Stop() + + client, err := Dial(httpsrv.URL) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + client.SetHeader("test", "ok") + if _, err := client.SupportedModules(); err != nil { + t.Fatal(err) + } + if !gotHeader { + t.Fatal("client did not set custom header") + } + + //NOTE: this test is removed because we accept invalid content types + // Check that Content-Type can be replaced. + //client.SetHeader("content-type", "application/x-garbage") + //_, err = client.SupportedModules() + //if err == nil { + // t.Fatal("no error for invalid content-type header") + //} else if !strings.Contains(err.Error(), "Unsupported Media Type") { + // t.Fatalf("error is not related to content-type: %q", err) + //} +} + +func TestClientHTTP(t *testing.T) { + server := newTestServer() + defer server.Stop() + + client, hs := httpTestClient(server, "http", nil) + defer hs.Close() + defer client.Close() + + // Launch concurrent requests. + var ( + results = make([]echoResult, 100) + errc = make(chan error, len(results)) + wantResult = echoResult{"a", 1, new(echoArgs)} + ) + defer client.Close() + for i := range results { + i := i + go func() { + errc <- client.Call(nil, &results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args) + }() + } + + // Wait for all of them to complete. + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + for i := range results { + select { + case err := <-errc: + if err != nil { + t.Fatal(err) + } + case <-timeout.C: + t.Fatalf("timeout (got %d/%d) results)", i+1, len(results)) + } + } + + // Check results. + for i := range results { + if !reflect.DeepEqual(results[i], wantResult) { + t.Errorf("result %d mismatch: got %#v, want %#v", i, results[i], wantResult) + } + } +} + +func TestClientReconnect(t *testing.T) { + startServer := func(addr string) (*Server, net.Listener) { + srv := newTestServer() + l, err := net.Listen("tcp", addr) + if err != nil { + t.Fatal("can't listen:", err) + } + go http.Serve(l, srv.WebsocketHandler([]string{"*"})) + return srv, l + } + + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + + // Start a server and corresponding client. + s1, l1 := startServer("127.0.0.1:0") + client, err := DialContext(ctx, "ws://"+l1.Addr().String()) + if err != nil { + t.Fatal("can't dial", err) + } + defer client.Close() + + // Perform a call. This should work because the server is up. + var resp echoResult + if err := client.Call(ctx, &resp, "test_echo", "", 1, nil); err != nil { + t.Fatal(err) + } + + // Shut down the server and allow for some cool down time so we can listen on the same + // address again. + l1.Close() + s1.Stop() + time.Sleep(2 * time.Second) + + // Try calling again. It shouldn't work. + if err := client.Call(ctx, &resp, "test_echo", "", 2, nil); err == nil { + t.Error("successful call while the server is down") + t.Logf("resp: %#v", resp) + } + + // Start it up again and call again. The connection should be reestablished. + // We spawn multiple calls here to check whether this hangs somehow. + s2, l2 := startServer(l1.Addr().String()) + defer l2.Close() + defer s2.Stop() + + start := make(chan struct{}) + errors := make(chan error, 20) + for i := 0; i < cap(errors); i++ { + go func() { + <-start + var resp echoResult + errors <- client.Call(ctx, &resp, "test_echo", "", 3, nil) + }() + } + close(start) + errcount := 0 + for i := 0; i < cap(errors); i++ { + if err = <-errors; err != nil { + errcount++ + } + } + t.Logf("%d errors, last error: %v", errcount, err) + if errcount > 1 { + t.Errorf("expected one error after disconnect, got %d", errcount) + } +} + +func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) { + // Create the HTTP server. + var hs *httptest.Server + switch transport { + case "ws": + hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) + case "http": + hs = httptest.NewUnstartedServer(srv) + default: + panic("unknown HTTP transport: " + transport) + } + // Wrap the listener if required. + if fl != nil { + fl.Listener = hs.Listener + hs.Listener = fl + } + // Connect the client. + hs.Start() + client, err := Dial(transport + "://" + hs.Listener.Addr().String()) + if err != nil { + panic(err) + } + return client, hs +} + +func ipcTestClient(srv *Server, fl *flakeyListener) (*Client, net.Listener) { + // Listen on a random endpoint. + endpoint := fmt.Sprintf("go-ethereum-test-ipc-%d-%d", os.Getpid(), rand.Int63()) + if runtime.GOOS == "windows" { + endpoint = `\\.\pipe\` + endpoint + } else { + endpoint = os.TempDir() + "/" + endpoint + } + l, err := ipcListen(endpoint) + if err != nil { + panic(err) + } + // Connect the listener to the server. + if fl != nil { + fl.Listener = l + l = fl + } + go srv.ServeListener(l) + // Connect the client. + client, err := Dial(endpoint) + if err != nil { + panic(err) + } + return client, l +} + +// flakeyListener kills accepted connections after a random timeout. +type flakeyListener struct { + net.Listener + maxKillTimeout time.Duration + maxAcceptDelay time.Duration +} + +func (l *flakeyListener) Accept() (net.Conn, error) { + delay := time.Duration(rand.Int63n(int64(l.maxAcceptDelay))) + time.Sleep(delay) + + c, err := l.Listener.Accept() + if err == nil { + timeout := time.Duration(rand.Int63n(int64(l.maxKillTimeout))) + time.AfterFunc(timeout, func() { + log.Debug().Msg(fmt.Sprintf("killing conn %v after %v", c.LocalAddr(), timeout)) + c.Close() + }) + } + return c, err +} diff --git a/pkg/jrpctest/server.go b/pkg/jrpctest/server.go index 690923da7fa919ab2f156c1d167fdd090bf01471..6eb377cba88caf065c967d26520f086d99bfb21b 100644 --- a/pkg/jrpctest/server.go +++ b/pkg/jrpctest/server.go @@ -1,33 +1,80 @@ package jrpctest import ( + "strings" + "gfx.cafe/open/jrpc" "gfx.cafe/open/jrpc/pkg/jmux" ) -func NewTestServer() *jrpc.Server { +func NewServer() *jrpc.Server { + server := jrpc.NewServer(NewRouter()) + return server +} +func NewRouter() *jmux.Mux { mux := jmux.NewRouter() - server := jrpc.NewServer(mux) - mux.HandleFunc("testservice_subscribe", func(w jrpc.ResponseWriter, r *jrpc.Request) { - sub, err := jrpc.UpgradeToSubscription(w, r) - w.Send(sub, err) - if err != nil { - return - } - idx := 0 - for { - err := w.Notify(idx) - if err != nil { - return - } - idx = idx + 1 - } - }) + //mux.HandleFunc("testservice_subscribe", func(w jrpc.ResponseWriter, r *jrpc.Request) { + // sub, err := jrpc.UpgradeToSubscription(w, r) + // w.Send(sub, err) + // if err != nil { + // return + // } + // idx := 0 + // for { + // err := w.Notify(idx) + // if err != nil { + // return + // } + // idx = idx + 1 + // } + //}) if err := mux.RegisterStruct("test", new(testService)); err != nil { panic(err) } if err := mux.RegisterStruct("nftest", new(notificationTestService)); err != nil { panic(err) } - return server + + if err := mux.RegisterStruct("large", largeRespService{1024 * 1024 * 5 * 3}); err != nil { + panic(err) + } + return mux +} +func NewRouterWithMaxSize(size int) *jmux.Mux { + mux := jmux.NewRouter() + //mux.HandleFunc("testservice_subscribe", func(w jrpc.ResponseWriter, r *jrpc.Request) { + // sub, err := jrpc.UpgradeToSubscription(w, r) + // w.Send(sub, err) + // if err != nil { + // return + // } + // idx := 0 + // for { + // err := w.Notify(idx) + // if err != nil { + // return + // } + // idx = idx + 1 + // } + //}) + if err := mux.RegisterStruct("test", new(testService)); err != nil { + panic(err) + } + if err := mux.RegisterStruct("nftest", new(notificationTestService)); err != nil { + panic(err) + } + + if err := mux.RegisterStruct("large", largeRespService{size}); err != nil { + panic(err) + } + return mux +} + +// largeRespService generates arbitrary-size JSON responses. +type largeRespService struct { + length int +} + +func (x largeRespService) LargeResp() string { + return strings.Repeat("x", x.length) } diff --git a/pkg/jrpctest/services.go b/pkg/jrpctest/services.go index a9ec74e7906721b78defd24f8f1ebc5113cbbd91..653a532b485611620f91983416ddf989bbcf9a20 100644 --- a/pkg/jrpctest/services.go +++ b/pkg/jrpctest/services.go @@ -3,23 +3,24 @@ package jrpctest import ( "context" "errors" - "gfx.cafe/open/jrpc/pkg/codec" "strings" "time" + "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/open/jrpc" ) type testService struct{} -type echoArgs struct { +type EchoArgs struct { S string } type EchoResult struct { String string Int int - Args *echoArgs + Args *EchoArgs } type testError struct{} @@ -34,11 +35,11 @@ func (s *testService) EchoAny(n any) any { return n } -func (s *testService) Echo(str string, i int, args *echoArgs) EchoResult { +func (s *testService) Echo(str string, i int, args *EchoArgs) EchoResult { return EchoResult{str, i, args} } -func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *echoArgs) EchoResult { +func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *EchoArgs) EchoResult { return EchoResult{str, i, args} } @@ -77,24 +78,24 @@ func (s *testService) ReturnError() error { } func (s *testService) CallMeBack(ctx context.Context, method string, args []any) (any, error) { - c, ok := jrpc.ClientFromContext(ctx) + c, ok := jrpc.ConnFromContext(ctx) if !ok { return nil, errors.New("no client") } var result any - err := c.Call(nil, &result, method, args...) + err := c.Do(nil, &result, method, args) return result, err } func (s *testService) CallMeBackLater(ctx context.Context, method string, args []any) error { - c, ok := jrpc.ClientFromContext(ctx) + c, ok := jrpc.ConnFromContext(ctx) if !ok { return errors.New("no client") } go func() { <-ctx.Done() var result any - c.Call(nil, &result, method, args...) + c.Do(nil, &result, method, args) }() return nil } diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go new file mode 100644 index 0000000000000000000000000000000000000000..d709f1d68e789cf3ffb970a8632c45ce204203e4 --- /dev/null +++ b/pkg/jrpctest/suites.go @@ -0,0 +1,43 @@ +package jrpctest + +import ( + "reflect" + "testing" + + "gfx.cafe/open/jrpc" + "github.com/stretchr/testify/require" +) + +type ClientMaker func() jrpc.Conn +type ServerMaker func() (*jrpc.Server, ClientMaker, func()) + +type BasicTestSuiteArgs struct { + ServerMaker ServerMaker +} + +type TestContext func(t *testing.T, server *jrpc.Server, client jrpc.Conn) + +func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { + var executeTest = func(t *testing.T, c TestContext) { + server, dialer, cn := args.ServerMaker() + defer cn() + defer server.Stop() + client := dialer() + defer client.Close() + c(t, server, client) + } + + var makeTest = func(name string, fm func(t *testing.T, server *jrpc.Server, client jrpc.Conn)) { + t.Run(name, func(t *testing.T) { + executeTest(t, fm) + }) + } + makeTest("ClientRequest", func(t *testing.T, server *jrpc.Server, client jrpc.Conn) { + var resp EchoResult + err := client.Do(nil, &resp, "test_echo", []any{"hello", 10, &EchoArgs{"world"}}) + require.NoError(t, err) + if !reflect.DeepEqual(resp, EchoResult{"hello", 10, &EchoArgs{"world"}}) { + t.Errorf("incorrect result %#v", resp) + } + }) +} diff --git a/request.go b/request.go index 2f11ffe57663582ea66f1db8a9575161a75feb15..315f6ffe6dee0abb324058cf77fcf6b70fd8ff63 100644 --- a/request.go +++ b/request.go @@ -31,6 +31,9 @@ type RequestMarshaling struct { } func NewRequestInt(ctx context.Context, id int, method string, params any) *Request { + if ctx == nil { + ctx = context.Background() + } r := &Request{ctx: ctx} pms, _ := json.Marshal(params) r.ID = codec2.NewNumberIDPtr(int64(id)) @@ -40,6 +43,9 @@ func NewRequestInt(ctx context.Context, id int, method string, params any) *Requ } func NewRequest(ctx context.Context, id string, method string, params any) *Request { + if ctx == nil { + ctx = context.Background() + } r := &Request{ctx: ctx} pms, _ := json.Marshal(params) r.ID = codec2.NewStringIDPtr(id) @@ -48,6 +54,18 @@ func NewRequest(ctx context.Context, id string, method string, params any) *Requ return r } +func NewNotification(ctx context.Context, method string, params any) *Request { + if ctx == nil { + ctx = context.Background() + } + r := &Request{ctx: ctx} + pms, _ := json.Marshal(params) + r.ID = nil + r.Method = method + r.Params = pms + return r +} + func (r *Request) makeError(err error) *codec2.Message { m := r.Msg() return m.ErrorResponse(err) diff --git a/server.go b/server.go index a69393de86ec41eae9850000430c7ca2b89be8e2..88ec9b1943590778ecfaa255d6c89869d7be866d 100644 --- a/server.go +++ b/server.go @@ -2,12 +2,13 @@ package jrpc import ( "context" - codec2 "gfx.cafe/open/jrpc/pkg/codec" "io" "net/http" "sync" "sync/atomic" + "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/util/go/bufpool" mapset "github.com/deckarep/golang-set" @@ -24,7 +25,7 @@ type Server struct { } type Tracing struct { - ErrorLogger func(remote codec2.ReaderWriter, err error) + ErrorLogger func(remote codec.ReaderWriter, err error) } // NewServer creates a new server instance with no registered handlers. @@ -39,7 +40,7 @@ func NewServer(r Handler) *Server { return server } -func (s *Server) printError(remote codec2.ReaderWriter, err error) { +func (s *Server) printError(remote codec.ReaderWriter, err error) { if err != nil { return } @@ -51,7 +52,7 @@ func (s *Server) printError(remote codec2.ReaderWriter, err error) { // ServeCodec reads incoming requests from codec, calls the appropriate callback and writes // the response back using the given codec. It will block until the codec is closed or the // server is stopped. In either case the codec is closed. -func (s *Server) ServeCodec(pctx context.Context, remote codec2.ReaderWriter) { +func (s *Server) ServeCodec(pctx context.Context, remote codec.ReaderWriter) { defer remote.Close() // Don't serve if server is stopped. @@ -70,6 +71,7 @@ func (s *Server) ServeCodec(pctx context.Context, remote codec2.ReaderWriter) { ctx, cn := context.WithCancel(pctx) defer cn() + ctx = ContextWithPeerInfo(ctx, remote.PeerInfo()) go func() { defer cn() err := responder.run(ctx) @@ -93,10 +95,11 @@ func (s *Server) ServeCodec(pctx context.Context, remote codec2.ReaderWriter) { for { msgs, err := remote.ReadBatch(ctx) if err != nil { + remote.Flush() s.printError(remote, err) return } - msg, batch := codec2.ParseMessage(msgs) + msg, batch := codec.ParseMessage(msgs) env := &callEnv{ batch: batch, } @@ -138,7 +141,7 @@ func (s *Server) ServeCodec(pctx context.Context, remote codec2.ReaderWriter) { type callResponder struct { toSend chan *callEnv toNotify chan *notifyEnv - remote codec2.ReaderWriter + remote codec.ReaderWriter } func (c *callResponder) run(ctx context.Context) error { @@ -157,6 +160,9 @@ func (c *callResponder) run(ctx context.Context) error { return err } } + if c.remote != nil { + c.remote.Flush() + } } } func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error { @@ -172,7 +178,7 @@ func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error { err := env.dat(buf) if err != nil { enc.FieldStart("error") - err := codec2.EncodeError(enc, err) + err := codec.EncodeError(enc, err) if err != nil { return err } @@ -202,7 +208,6 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { if v.msg.ID == nil { continue } - buf.Reset() enc.ObjStart() enc.FieldStart("jsonrpc") enc.Str("2.0") @@ -210,17 +215,18 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) error { enc.Raw(v.msg.ID.RawMessage()) err := v.err if err == nil && v.dat != nil { + buf.Reset() err = v.dat(buf) - if err != nil { + if err == nil { enc.FieldStart("result") enc.Raw(buf.Bytes()) } } else { - err = codec2.NewMethodNotFoundError(v.msg.Method) + err = codec.NewMethodNotFoundError(v.msg.Method) } if err != nil { enc.FieldStart("error") - err := codec2.EncodeError(enc, err) + err := codec.EncodeError(enc, err) if err != nil { return err } @@ -248,7 +254,7 @@ type notifyEnv struct { } type callRespWriter struct { - msg *codec2.Message + msg *codec.Message dat func(io.Writer) error err error skip bool @@ -291,7 +297,7 @@ func (c *callRespWriter) Notify(v any) error { func (s *Server) Stop() { if atomic.CompareAndSwapInt32(&s.run, 1, 0) { s.codecs.Each(func(c any) bool { - c.(codec2.ReaderWriter).Close() + c.(codec.ReaderWriter).Close() return true }) } @@ -303,7 +309,10 @@ type peerInfoContextKey struct{} // Use this with the context passed to RPC method handler functions. // // The zero value is returned if no connection info is present in ctx. -func PeerInfoFromContext(ctx context.Context) codec2.PeerInfo { - info, _ := ctx.Value(peerInfoContextKey{}).(codec2.PeerInfo) +func PeerInfoFromContext(ctx context.Context) codec.PeerInfo { + info, _ := ctx.Value(peerInfoContextKey{}).(codec.PeerInfo) return info } +func ContextWithPeerInfo(ctx context.Context, c codec.PeerInfo) context.Context { + return context.WithValue(ctx, peerInfoContextKey{}, c) +}