diff --git a/pkg/server/batching.go b/pkg/server/batching.go new file mode 100644 index 0000000000000000000000000000000000000000..e69c2dddf3c20fdaca8bc65bb83c1c7f84067a6a --- /dev/null +++ b/pkg/server/batching.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + "sync" + + "gfx.cafe/open/jrpc/pkg/jsonrpc" + "github.com/mailgun/multibuf" +) + +// serving batches is a bit complicated and we don't even use it +func serveBatch(ctx context.Context, + incoming []*jsonrpc.Message, + r *callResponder, + handler jsonrpc.Handler, +) error { + // check for empty batch + if r.batch && len(incoming) == 0 { + // if it is empty batch, send the empty batch error and immediately return + mw, err := r.stream.NewMessage(ctx) + if err != nil { + return err + } + defer mw.Close() + if err := mw.Field("id", jsonrpc.Null); err != nil { + return err + } + if err := mw.Field("error", jsonrpc.MarshalError(jsonrpc.NewInvalidRequestError("empty batch"))); err != nil { + return err + } + return nil + } + + totalRequests := 0 + // populate the envelope we are about to send. this is synchronous pre-prpcessing + ansBuf, err := multibuf.NewWriterOnce( + // store up to 16mb per batch in memory + multibuf.MemBytes(16*1024*1024), + // store up to 256gb per batch on disk + multibuf.MaxBytes(256*1204*1024*1024), + ) + defer ansBuf.Close() + if err != nil { + return err + } + ansStream := jsonrpc.NewStream(ansBuf) + ansBatch, err := ansStream.NewBatch(ctx) + if err != nil { + return err + } + + // create a waitgroup for when every handler returns + returnWg := sync.WaitGroup{} + returnWg.Add(len(incoming)) + for _, v := range incoming { + canNext := make(chan struct{}) + // create the response writer + om, omerr := produceOutputMessage(v) + rw := &streamingRespWriter{ + ctx: ctx, + sendStream: ansBatch, + notifyStream: r.stream, + id: om.ID, + err: omerr, + } + if rw.id != nil { + totalRequests += 1 + rw.done = func() { + close(canNext) + } + } + req := jsonrpc.NewRawRequest( + ctx, + om.ID, + om.Method, + om.Params, + ) + req.Peer = r.peerinfo + go func() { + defer returnWg.Done() + handler.ServeRPC(rw, req) + if rw.sendCalled == false && rw.id != nil { + rw.Send(jsonrpc.Null, nil) + } + }() + if rw.id != nil { + <-canNext + } + } + + err = ansBatch.Close() + if err != nil { + return err + } + + mr, err := ansBuf.Reader() + if err != nil { + return err + } + defer mr.Close() + + if totalRequests > 0 { + // TODO: channel? + err := r.stream.ReadFrom(ctx, mr) + if err != nil { + return err + } + } else if totalRequests == 0 { + // all notification, so immediately flush, and that's the whole message + err := r.stream.Flush(ctx) + if err != nil { + return err + } + } + // wait for the returnWg to return + returnWg.Wait() + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 300652dab2a2a729ed338ac9085e40214fcc55d9..8b822c7c6f430f8643a868be73a7712192ab2ada 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,9 +3,7 @@ package server import ( "context" "errors" - "sync" - "github.com/mailgun/multibuf" "golang.org/x/sync/errgroup" "gfx.cafe/open/jrpc/pkg/jsonrpc" @@ -41,7 +39,7 @@ func ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter, handler jsonrp // read messages from the stream synchronously incoming, batch, err := remote.ReadBatch(ctx) if err != nil { - if errors.Is(err, jsonrpc.ErrNoMoreBatches) { + if errors.Is(err, jsonrpc.ErrNoMoreBatches) || errors.Is(err, context.Canceled) { return } select { @@ -81,6 +79,12 @@ func ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter, handler jsonrp } } +type callResponder struct { + peerinfo jsonrpc.PeerInfo + stream *jsonrpc.MessageStream + batch bool +} + func serve(ctx context.Context, incoming []*jsonrpc.Message, r *callResponder, @@ -98,14 +102,14 @@ func serveSingle(ctx context.Context, r *callResponder, handler jsonrpc.Handler, ) error { + om, omerr := produceOutputMessage(incoming) rw := &streamingRespWriter{ ctx: ctx, sendStream: r.stream, notifyStream: r.stream, + id: om.ID, + err: omerr, } - om, omerr := produceOutputMessage(incoming) - rw.id = om.ID - rw.err = omerr req := jsonrpc.NewRawRequest( ctx, rw.id, @@ -127,114 +131,6 @@ func serveSingle(ctx context.Context, return nil } -func serveBatch(ctx context.Context, - incoming []*jsonrpc.Message, - r *callResponder, - handler jsonrpc.Handler, -) error { - // check for empty batch - if r.batch && len(incoming) == 0 { - // if it is empty batch, send the empty batch error and immediately return - mw, err := r.stream.NewMessage(ctx) - if err != nil { - return err - } - defer mw.Close() - if err := mw.Field("id", jsonrpc.Null); err != nil { - return err - } - if err := mw.Field("error", jsonrpc.MarshalError(jsonrpc.NewInvalidRequestError("empty batch"))); err != nil { - return err - } - return nil - } - - totalRequests := 0 - // populate the envelope we are about to send. this is synchronous pre-prpcessing - ansBuf, err := multibuf.NewWriterOnce( - // store up to 16mb per batch in memory - multibuf.MemBytes(16*1024*1024), - // store up to 256gb per batch on disk - multibuf.MaxBytes(256*1204*1024*1024), - ) - defer ansBuf.Close() - if err != nil { - return err - } - ansStream := jsonrpc.NewStream(ansBuf) - ansBatch, err := ansStream.NewBatch(ctx) - if err != nil { - return err - } - - // create a waitgroup for when every handler returns - returnWg := sync.WaitGroup{} - returnWg.Add(len(incoming)) - for _, v := range incoming { - canNext := make(chan struct{}) - // create the response writer - rw := &streamingRespWriter{ - ctx: ctx, - sendStream: ansBatch, - notifyStream: r.stream, - } - om, omerr := produceOutputMessage(v) - rw.id = om.ID - rw.err = omerr - if rw.id != nil { - totalRequests += 1 - rw.done = func() { - close(canNext) - } - } - req := jsonrpc.NewRawRequest( - ctx, - om.ID, - om.Method, - om.Params, - ) - req.Peer = r.peerinfo - go func() { - defer returnWg.Done() - handler.ServeRPC(rw, req) - if rw.sendCalled == false && rw.id != nil { - rw.Send(jsonrpc.Null, nil) - } - }() - if rw.id != nil { - <-canNext - } - } - - err = ansBatch.Close() - if err != nil { - return err - } - - mr, err := ansBuf.Reader() - if err != nil { - return err - } - defer mr.Close() - - if totalRequests > 0 { - // TODO: channel? - err := r.stream.ReadFrom(ctx, mr) - if err != nil { - return err - } - } else if totalRequests == 0 { - // all notification, so immediately flush, and that's the whole message - err := r.stream.Flush(ctx) - if err != nil { - return err - } - } - // wait for the returnWg to return - returnWg.Wait() - return nil -} - func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message, err error) { // a nil incoming message means return an invalid request. if inputMessage == nil { @@ -255,15 +151,3 @@ func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message, return } - -type callResponder struct { - peerinfo jsonrpc.PeerInfo - stream *jsonrpc.MessageStream - batch bool -} - -type callEnv struct { - v any - err error - id *jsonrpc.ID -}