diff --git a/pkg/jsonrpc/errors.go b/pkg/jsonrpc/errors.go index 81f7e2391b3167a3385bd92ec85354ada3c1714a..5999db2badea2e6371fcea82c42d2669d99ff85e 100644 --- a/pkg/jsonrpc/errors.go +++ b/pkg/jsonrpc/errors.go @@ -8,8 +8,15 @@ import ( "github.com/go-faster/jx" ) -// Error types defined below are the built-in JSON-RPC errors. +var ( + ErrIllegalExtraField = errors.New("invalid extra field") + ErrSendAlreadyCalled = errors.New("send already called") + ErrHijackAlreadyCalled = errors.New("already hijacked") + ErrCantSendNotification = errors.New("can't send to a notification") + ErrNotSupported = errors.New("not supported") +) +// Error types defined below are the built-in JSON-RPC errors. var ( _ Error = new(ErrorMethodNotFound) _ Error = new(ErrorSubscriptionNotFound) @@ -25,12 +32,6 @@ const ( ErrorCodeJrpc = -42000 ) -var ( - ErrIllegalExtraField = errors.New("invalid extra field") - ErrSendAlreadyCalled = errors.New("send already called") - ErrCantSendNotification = errors.New("can't send to a notification") -) - // Error wraps RPC errors, which contain an error code in addition to the message. type Error interface { Error() string // returns the message diff --git a/pkg/jsonrpc/features.go b/pkg/jsonrpc/features.go new file mode 100644 index 0000000000000000000000000000000000000000..4deef67ce3af8895ea2c1c45f6c800a067d0f7cd --- /dev/null +++ b/pkg/jsonrpc/features.go @@ -0,0 +1,5 @@ +package jsonrpc + +type Hijacker interface { + Hijack() (send MessageStreamer, notify MessageStreamer, err error) +} diff --git a/pkg/jsonrpc/reqresp.go b/pkg/jsonrpc/reqresp.go index 1cf53e7564319191edfcbc17f52d8664da39b1b0..edca32ebece2116fcf90a957a6eaf5059797afa5 100644 --- a/pkg/jsonrpc/reqresp.go +++ b/pkg/jsonrpc/reqresp.go @@ -13,12 +13,6 @@ type ResponseWriter interface { Notify(method string, v any) error } -type StreamingResponseWriter interface { - ResponseWriter - SendStream(func(MessageStreamer) error) error - NotifyStream(func(MessageStreamer) error) error -} - type Request struct { ID *ID `json:"id,omitempty"` Method string `json:"method,omitempty"` diff --git a/pkg/jsonrpc/responsecontroller.go b/pkg/jsonrpc/responsecontroller.go new file mode 100644 index 0000000000000000000000000000000000000000..b821dd8294e74b69828a307b1c8bd0cb1844123d --- /dev/null +++ b/pkg/jsonrpc/responsecontroller.go @@ -0,0 +1,35 @@ +package jsonrpc + +import ( + "fmt" +) + +type ResponseController struct { + rw ResponseWriter +} + +func NewResponseController(rw ResponseWriter) *ResponseController { + return &ResponseController{rw} +} + +type rwUnwrapper interface { + Unwrap() ResponseWriter +} + +func (c *ResponseController) Hijack() (sender MessageStreamer, notify MessageStreamer, err error) { + rw := c.rw + for { + switch t := rw.(type) { + case Hijacker: + return t.Hijack() + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, nil, errNotSupported() + } + } +} + +func errNotSupported() error { + return fmt.Errorf("%w", ErrNotSupported) +} diff --git a/pkg/server/batching.go b/pkg/server/batching.go index e69c2dddf3c20fdaca8bc65bb83c1c7f84067a6a..9a70bdadbdd87926c8bca7837334f4c34f382f83 100644 --- a/pkg/server/batching.go +++ b/pkg/server/batching.go @@ -53,7 +53,6 @@ func serveBatch(ctx context.Context, returnWg := sync.WaitGroup{} returnWg.Add(len(incoming)) for _, v := range incoming { - canNext := make(chan struct{}) // create the response writer om, omerr := produceOutputMessage(v) rw := &streamingRespWriter{ @@ -65,9 +64,6 @@ func serveBatch(ctx context.Context, } if rw.id != nil { totalRequests += 1 - rw.done = func() { - close(canNext) - } } req := jsonrpc.NewRawRequest( ctx, @@ -76,16 +72,14 @@ func serveBatch(ctx context.Context, om.Params, ) req.Peer = r.peerinfo - go func() { + run := func() { defer returnWg.Done() handler.ServeRPC(rw, req) if rw.sendCalled == false && rw.id != nil { rw.Send(jsonrpc.Null, nil) } - }() - if rw.id != nil { - <-canNext } + run() } err = ansBatch.Close() diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go index 6e9911d316b6c40c9cccd5f2b37da3d84b498413..4551818176f55e2ac5a40b09bb471efdfddfde3f 100644 --- a/pkg/server/responsewriter.go +++ b/pkg/server/responsewriter.go @@ -19,41 +19,33 @@ type streamingRespWriter struct { // the id to write the response with id *jsonrpc.ID - // a function that is called on the first call to send - // it's optional - done func() - // if set, will ensure that send will always send this error, instead of whatever send does err error // marks whether or not send was called. it may only be called once sendCalled bool + // marks whether or not hijack was called + hijackCalled bool } -func (c *streamingRespWriter) SendStream(fn func(jsonrpc.MessageStreamer) error) error { - if c.sendCalled { - return jsonrpc.ErrSendAlreadyCalled +func (c *streamingRespWriter) Hijack() (sender jsonrpc.MessageStreamer, notify jsonrpc.MessageStreamer, err error) { + if c.hijackCalled { + return nil, nil, jsonrpc.ErrHijackAlreadyCalled } + c.hijackCalled = true c.sendCalled = true - if c.done != nil { - defer c.done() - } - return fn(c.sendStream) -} - -func (c *streamingRespWriter) NotifyStream(fn func(jsonrpc.MessageStreamer) error) error { - return fn(c.notifyStream) + return c.sendStream, c.notifyStream, nil } func (c *streamingRespWriter) Send(v any, e error) (err error) { + if c.hijackCalled { + return jsonrpc.ErrHijackAlreadyCalled + } if c.id == nil { return jsonrpc.ErrCantSendNotification } if c.sendCalled { return jsonrpc.ErrSendAlreadyCalled } - if c.done != nil { - defer c.done() - } c.sendCalled = true sentErr := c.err // only override error if not already set @@ -82,6 +74,10 @@ func (c *streamingRespWriter) Send(v any, e error) (err error) { } func (c *streamingRespWriter) Notify(method string, v any) error { + + if c.hijackCalled { + return jsonrpc.ErrHijackAlreadyCalled + } msg, err := c.notifyStream.NewMessage(c.ctx) if err != nil { return err