diff --git a/rpc/handler.go b/rpc/handler.go index 3f75a6dae9226de727dd2c4dbe338ddfde84878b..734b27b6402722537cd2eb11cea669596d8a1d8f 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -17,7 +17,6 @@ package rpc import ( - "bytes" "context" "encoding/json" "fmt" @@ -28,7 +27,6 @@ import ( "time" jsoniter "github.com/json-iterator/go" - "github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/log" ) @@ -104,6 +102,12 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg * // handleBatch executes all messages in a batch and returns the responses. func (h *handler) handleBatch(msgs []*jsonrpcMessage, stream *jsoniter.Stream) { + needWriteStream := false + if stream == nil { + stream = jsoniter.NewStream(jsoniter.ConfigDefault, nil, 4096) + needWriteStream = true + } + // Emit error response for empty batches: if len(msgs) == 0 { h.startCallProc(func(cp *callProc) { @@ -124,72 +128,74 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage, stream *jsoniter.Stream) { } // Process calls on a goroutine because they may block indefinitely: h.startCallProc(func(cp *callProc) { - allMethodsAreThreadSafe := true // only if all methods in batch are pass next criteria - for i := range calls { - if calls[i].isSubscribe() { - allMethodsAreThreadSafe = false - break + stream.WriteArrayStart() + firstResponse := true + // All goroutines will place results right to this array. Because requests order must match reply orders. + // Bounded parallelism pattern explanation https://blog.golang.org/pipelines#TOC_9. + boundedConcurrency := make(chan struct{}, h.maxBatchConcurrency) + defer close(boundedConcurrency) + wg := sync.WaitGroup{} + wg.Add(len(msgs)) + streamMutex := sync.Mutex{} + + writeToStream := func(buffer []byte) { + if len(buffer) == 0 { + return } - cb := h.reg.callback(calls[i].Method) - if cb != nil && cb.streamable { // cb == nil: means no such method and this case is thread-safe - allMethodsAreThreadSafe = false - break + + streamMutex.Lock() + defer streamMutex.Unlock() + + if !firstResponse { + stream.WriteMore() } - } - if !allMethodsAreThreadSafe && stream == nil { - _ = h.conn.writeJSON(context.Background(), jsonrpcMessage{Version: vsn, ID: null, Error: &jsonError{ - Code: -32601, - Message: "streamable methods are not supported on websockets. help us to implement", - }}) - return + stream.Write(buffer) + firstResponse = false } - answers := make([]interface{}, 0, len(msgs)) - if allMethodsAreThreadSafe { - // All goroutines will place results right to this array. Because requests order must match reply orders. - answersWithNils := make([]*jsonrpcMessage, len(msgs)) - // Bounded parallelism pattern explanation https://blog.golang.org/pipelines#TOC_9. - boundedConcurrency := make(chan struct{}, h.maxBatchConcurrency) - defer close(boundedConcurrency) - wg := sync.WaitGroup{} - wg.Add(len(msgs)) - for i := range calls { - boundedConcurrency <- struct{}{} - go func(i int) { - defer func() { - wg.Done() - <-boundedConcurrency - }() - - answersWithNils[i] = h.handleCallMsg(cp, calls[i], stream) - }(i) - } - wg.Wait() - for _, answer := range answersWithNils { - if answer != nil { - answers = append(answers, answer) + for i := range calls { + if calls[i].isSubscribe() { + // Force subscribe call to work in non-streaming mode + response := h.handleCallMsg(cp, calls[i], nil) + if response != nil { + b, _ := json.Marshal(response) + writeToStream(b) } } - } else { - answers = make([]interface{}, 0, len(msgs)) - buf := bytes.NewBuffer(nil) - stream := jsoniter.NewStream(jsoniter.ConfigDefault, buf, 4096) - for _, msg := range calls { - buf.Reset() - stream.Reset(buf) - if answer := h.handleCallMsg(cp, msg, stream); answer != nil { - answers = append(answers, answer) + boundedConcurrency <- struct{}{} + go func(i int) { + defer func() { + wg.Done() + <-boundedConcurrency + }() + cb := h.reg.callback(calls[i].Method) + var response *jsonrpcMessage + if cb != nil && cb.streamable { // cb == nil: means no such method and this case is thread-safe + batchStream := jsoniter.NewStream(jsoniter.ConfigDefault, nil, 4096) + response = h.handleCallMsg(cp, calls[i], batchStream) + writeToStream(batchStream.Buffer()) } else { - if buf.Len() > 0 { - answers = append(answers, json.RawMessage(common.CopyBytes(buf.Bytes()))) - } + response = h.handleCallMsg(cp, calls[i], stream) } - } + // Marshal inside goroutine (parallel) + if response != nil { + buffer, _ := json.Marshal(response) + writeToStream(buffer) + } + }(i) } - h.addSubscriptions(cp.notifiers) - if len(answers) > 0 { - h.conn.writeJSON(cp.ctx, answers) + wg.Wait() + + stream.WriteArrayEnd() + stream.Flush() + + if needWriteStream { + h.conn.writeJSON(cp.ctx, json.RawMessage(stream.Buffer())) + } else { + stream.Write([]byte("\n")) } + + h.addSubscriptions(cp.notifiers) for _, n := range cp.notifiers { n.activate() } @@ -202,10 +208,21 @@ func (h *handler) handleMsg(msg *jsonrpcMessage, stream *jsoniter.Stream) { return } h.startCallProc(func(cp *callProc) { + needWriteStream := false + if stream == nil { + stream = jsoniter.NewStream(jsoniter.ConfigDefault, nil, 4096) + needWriteStream = true + } answer := h.handleCallMsg(cp, msg, stream) h.addSubscriptions(cp.notifiers) if answer != nil { - h.conn.writeJSON(cp.ctx, answer) + buffer, _ := json.Marshal(answer) + stream.Write(json.RawMessage(buffer)) + } + if needWriteStream { + h.conn.writeJSON(cp.ctx, json.RawMessage(stream.Buffer())) + } else { + stream.Write([]byte("\n")) } for _, n := range cp.notifiers { n.activate() @@ -499,7 +516,6 @@ func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *cal stream.WriteObjectEnd() } stream.WriteObjectEnd() - stream.Write([]byte("\n")) stream.Flush() return nil } else { diff --git a/rpc/server_test.go b/rpc/server_test.go index 7c226cd7497fc1d5c4133d4255222a212e482514..157aca11b67b5ca2e8c6983f8f11ab1ca6b6c7d6 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -19,10 +19,12 @@ package rpc import ( "bufio" "bytes" + "encoding/json" "io" "io/ioutil" "net" "path/filepath" + "sort" "strings" "testing" "time" @@ -102,6 +104,20 @@ func runTestScript(t *testing.T, file string) { t.Fatalf("read error: %v", err) } sent = strings.TrimRight(sent, "\r\n") + msgs, batch := parseMessage(json.RawMessage(sent)) + if batch { + sort.Slice(msgs, func(i, j int) bool { + return string(msgs[i].ID) < string(msgs[j].ID) + }) + b, _ := json.Marshal(msgs) + sent = string(b) + msgs, _ = parseMessage(json.RawMessage(want)) + sort.Slice(msgs, func(i, j int) bool { + return string(msgs[i].ID) < string(msgs[j].ID) + }) + b, _ = json.Marshal(msgs) + want = string(b) + } if sent != want { t.Errorf("wrong line from server\ngot: %s\nwant: %s", sent, want) } diff --git a/rpc/testdata/reqresp-batch.js b/rpc/testdata/reqresp-batch.js index 977af76630996cc8a929f158bb3ea3cb64af776d..964bc8084012dfec4f3260918231ec5f0ada2b41 100644 --- a/rpc/testdata/reqresp-batch.js +++ b/rpc/testdata/reqresp-batch.js @@ -1,7 +1,7 @@ // There is no response for all-notification batches. --> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}] - +<-- [] // This test checks regular batch calls. --> [{"jsonrpc":"2.0","id":2,"method":"test_echo","params":[]}, {"jsonrpc":"2.0","id": 3,"method":"test_echo","params":["x",3]}]