diff --git a/pkg/jsonrpc/message.go b/pkg/jsonrpc/message.go index dfaeacb8d47c6f3d5b418e9539b2c165bed109bb..f6b96a398503739dd5a7fe006aaee8307e2f7a83 100644 --- a/pkg/jsonrpc/message.go +++ b/pkg/jsonrpc/message.go @@ -4,56 +4,177 @@ import ( "encoding/json" "io" - "github.com/go-faster/jx" + "golang.org/x/net/context" + "golang.org/x/sync/semaphore" ) // MessageStream is a writer used to write jsonrpc message to a stream type MessageStream struct { w io.Writer - jx *jx.Writer -} - -func NewStream(w io.Writer) (*MessageStream, error) { - enc := jx.GetWriter() - defer jx.PutWriter(enc) - enc.Grow(4096) - enc.ResetWriter(w) - enc.ObjStart() - enc.FieldStart("jsonrpc") - enc.Str("2.0") - enc.Close() + mu *semaphore.Weighted +} + +func NewStream(w io.Writer) *MessageStream { return &MessageStream{ w: w, - jx: enc, - }, nil + mu: semaphore.NewWeighted(1), + } } -func (m *MessageStream) Field(name string, value json.RawMessage) error { - m.jx.ResetWriter(m.w) - m.jx.Comma() - m.jx.FieldStart(name) - m.jx.Raw(value) - return m.jx.Close() +type flusher interface { + Flush() error } -// Result returns a writecloser that writes to a result field -func (m *MessageStream) Result() (io.Writer, error) { - m.jx.ResetWriter(m.w) - m.jx.Comma() - m.jx.FieldStart("result") - m.jx.Close() - return &MessageWriter{w: m.w}, nil +func flushIfFlusher(w io.Writer) error { + if val, ok := w.(flusher); ok { + return val.Flush() + } + return nil } -func (m *MessageStream) Close() error { - _, err := m.w.Write([]byte("}")) - return err +// sends a flush in order to send an empty payload +func (m *MessageStream) Flush(ctx context.Context) error { + err := m.mu.Acquire(ctx, 1) + if err != nil { + return err + } + defer m.mu.Release(1) + return flushIfFlusher(m.w) } type MessageWriter struct { + w io.Writer + mu *semaphore.Weighted +} + +// NewMessage starts a new message and acquires the write lock. +// to free the write lock, you must call *MessageWriter.Close() +// the lock MUST be closed if and only if err == nil +func (m *MessageStream) NewMessage(ctx context.Context) (*MessageWriter, error) { + if m.mu != nil { + err := m.mu.Acquire(ctx, 1) + if err != nil { + return nil, err + } + } + _, err := m.w.Write([]byte(`{"jsonrpc":"2.0"`)) + if err != nil { + if m.mu != nil { + m.mu.Release(1) + } + return nil, err + } + return &MessageWriter{ + w: m.w, + mu: m.mu, + }, nil +} + +// close must be called when you are done writing the message. +// it releases the write lock +func (m *MessageWriter) Close() error { + if m.mu != nil { + defer m.mu.Release(1) + } + _, err := m.w.Write([]byte("}")) + if err != nil { + return err + } + return flushIfFlusher(m.w) +} + +func (m *MessageWriter) Field(name string, value json.RawMessage) error { + _, err := m.w.Write([]byte(`,"` + name + `":`)) + if err != nil { + return err + } + _, err = m.w.Write(value) + if err != nil { + return err + } + return nil +} + +// Result returns a writer that writes to a result field +func (m *MessageWriter) Result() (io.Writer, error) { + _, err := m.w.Write([]byte(`,"result":`)) + if err != nil { + return nil, err + } + return &ResultWriter{w: m.w}, nil +} + +type BatchWriter struct { + w io.Writer + mu *semaphore.Weighted + ms *MessageStream + isNotFirst bool +} + +type writer struct { + w io.Writer +} + +func (w *writer) Write(p []byte) (n int, err error) { + return w.w.Write(p) +} + +// Start writing a batch to the stream. this function acquires the lock +// caller MUST call Close() on the BatchWriter iff err == nil +func (m *MessageStream) NewBatch(ctx context.Context) (*BatchWriter, error) { + if m.mu != nil { + err := m.mu.Acquire(ctx, 1) + if err != nil { + return nil, err + } + } + _, err := m.w.Write([]byte("[")) + if err != nil { + if m.mu != nil { + m.mu.Release(1) + } + return nil, err + } + return &BatchWriter{ + w: m.w, + ms: &MessageStream{ + w: &writer{m.w}, + }, + mu: m.mu, + }, nil +} + +// Writes the next element in the batch. Note that the messagewriter is not thread safe +func (m *BatchWriter) Next(ctx context.Context) (*MessageWriter, error) { + if m.isNotFirst == false { + m.isNotFirst = true + } else { + // write comma if not the first element + _, err := m.w.Write([]byte(",")) + if err != nil { + return nil, err + } + } + return m.ms.NewMessage(ctx) +} + +// close must be called when you are done writing the batch. +// it releases the write lock +func (m *BatchWriter) Close() error { + if m.mu != nil { + defer m.mu.Release(1) + } + _, err := m.w.Write([]byte("]")) + if err != nil { + return err + } + return flushIfFlusher(m.w) +} + +type ResultWriter struct { w io.Writer } -func (m *MessageWriter) Write(p []byte) (n int, err error) { +func (m *ResultWriter) Write(p []byte) (n int, err error) { return m.w.Write(p) } diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go index eede8f9a7b30daea2de7fdbb897d6c4f8365ea50..fd8d7c39b0bb99438c742fbae16e01b79025ad9a 100644 --- a/pkg/server/responsewriter.go +++ b/pkg/server/responsewriter.go @@ -49,33 +49,27 @@ func (c *streamingRespWriter) Send(v any, e error) (err error) { if v != nil { ce.v = v } - err = c.cr.mu.Acquire(c.ctx, 1) + msg, err := c.cr.stream.NewMessage(c.ctx) if err != nil { return err } - defer c.cr.mu.Release(1) - if c.err != nil { - e = c.err - } - if err = c.cr.send(c.ctx, ce); err != nil { - return err - } - if err = c.cr.remote.Flush(); err != nil { + defer msg.Close() + if err = send(ce, msg); err != nil { return err } return nil } func (c *streamingRespWriter) Notify(method string, v any) error { - err := c.cr.mu.Acquire(c.ctx, 1) + msg, err := c.cr.stream.NewMessage(c.ctx) if err != nil { return err } - defer c.cr.mu.Release(1) - err = c.cr.notify(c.ctx, ¬ifyEnv{ + defer msg.Close() + err = c.cr.notify(¬ifyEnv{ method: method, dat: v, - }) + }, msg) if err != nil { return err } diff --git a/pkg/server/rw_batch.go b/pkg/server/rw_batch.go index e5bb473db52c721ce79d392941d124ebc146b77b..2a327cbbb5405354e56d9d8fbe3ab8686559f5a2 100644 --- a/pkg/server/rw_batch.go +++ b/pkg/server/rw_batch.go @@ -63,19 +63,15 @@ func (c *batchingRespWriter) Send(v any, e error) (err error) { } func (c *batchingRespWriter) Notify(method string, v any) error { - err := c.cr.mu.Acquire(c.ctx, 1) + msg, err := c.cr.stream.NewMessage(c.ctx) if err != nil { return err } - defer c.cr.mu.Release(1) - err = c.cr.notify(c.ctx, ¬ifyEnv{ + defer msg.Close() + err = c.cr.notify(¬ifyEnv{ method: method, dat: v, - }) - if err != nil { - return err - } - err = c.cr.remote.Flush() + }, msg) if err != nil { return err } diff --git a/pkg/server/server.go b/pkg/server/server.go index 67ea3edce19ddcd3dcc4fa41047dfc21bc098b23..0fbbf71c9dbadb88defc9af6f38f150266bfeed8 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -4,15 +4,10 @@ import ( "context" "encoding/json" "errors" - "io" "sync" - "golang.org/x/sync/semaphore" - "gfx.cafe/open/jrpc/pkg/jjson" "gfx.cafe/open/jrpc/pkg/jsonrpc" - - "github.com/go-faster/jx" ) // Server is an RPC server. @@ -37,9 +32,11 @@ func NewServer(r jsonrpc.Handler) *Server { func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error { defer remote.Close() - sema := semaphore.NewWeighted(1) + stream := jsonrpc.NewStream(remote) // add a cancel to the context so we can cancel all the child tasks on return - ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo())) + ctx = ContextWithPeerInfo(ctx, remote.PeerInfo()) + ctx = ContextWithMessageStream(ctx, stream) + ctx, cn := context.WithCancel(ctx) defer cn() allErrs := []error{} @@ -58,7 +55,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er responder := &callResponder{ remote: remote, batch: batch, - mu: sema, + stream: stream, } err = s.serve(ctx, incoming, responder) if err != nil { @@ -109,13 +106,8 @@ func (s *Server) serveSingle(ctx context.Context, ) req.Peer = r.remote.PeerInfo() if rw.msg.ID == nil { - // all notification, so immediately flush - err := r.mu.Acquire(ctx, 1) - if err != nil { - return err - } - defer r.mu.Release(1) - err = r.remote.Flush() + // all notification, so immediately flush a response + err := r.stream.Flush(ctx) if err != nil { return err } @@ -155,20 +147,15 @@ func (s *Server) serveBatch(ctx context.Context, // check for empty batch if r.batch && len(incoming) == 0 { // if it is empty batch, send the empty batch error and immediately return - err := r.mu.Acquire(ctx, 1) + mw, err := r.stream.NewMessage(ctx) if err != nil { return err } - defer r.mu.Release(1) - err = r.send(ctx, &callEnv{ - id: jsonrpc.NewNullIDPtr(), - err: jsonrpc.NewInvalidRequestError("empty batch"), - }) - if err != nil { + defer mw.Close() + if err := mw.Field("id", jsonrpc.Null); err != nil { return err } - err = r.remote.Flush() - if err != nil { + if err := mw.Field("error", jsonrpc.MarshalError(jsonrpc.NewInvalidRequestError("empty batch"))); err != nil { return err } return nil @@ -229,61 +216,55 @@ func (s *Server) serveBatch(ctx context.Context, if totalRequests > 0 { // TODO: channel? respWg.Wait() - err := r.mu.Acquire(ctx, 1) - if err != nil { - return err - } - defer r.mu.Release(1) - // write them, one by one - _, err = r.remote.Write([]byte{'['}) - if err != nil { - return err - } - for i, v := range batchResults { - err = r.send(ctx, &callEnv{ - v: v.payload, - err: v.err, - id: v.msg.ID, - }) + err := func() error { + batch, err := r.stream.NewBatch(ctx) if err != nil { return err } - // write the comma or ] - char := ',' - if i == len(batchResults)-1 { - char = ']' - } - _, err = r.remote.Write([]byte{byte(char)}) - if err != nil { - return err + defer batch.Close() + // write them, one by one + for _, v := range batchResults { + err := func() error { + msg, err := batch.Next(ctx) + if err != nil { + return err + } + defer msg.Close() + err = send(&callEnv{ + v: v.payload, + err: v.err, + id: v.msg.ID, + }, msg) + if err != nil { + return err + } + return nil + }() + if err != nil { + return err + } } - } - err = r.remote.Flush() + return nil + }() if err != nil { return err } } else if totalRequests == 0 { - // all notification, so immediately flush - err := r.mu.Acquire(ctx, 1) - if err != nil { - return err - } - defer r.mu.Release(1) - err = r.remote.Flush() + // all notification, so immediately flush, and that's the whole message + err := r.stream.Flush(ctx) if err != nil { return err } } + // wait for the returnWg to return returnWg.Wait() return nil } type callResponder struct { remote jsonrpc.ReaderWriter - mu *semaphore.Weighted - - batch bool - batchStarted bool + stream *jsonrpc.MessageStream + batch bool } type callEnv struct { @@ -292,13 +273,7 @@ type callEnv struct { id *jsonrpc.ID } -func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { - w := c.remote - s, err := jsonrpc.NewStream(w) - if err != nil { - return err - } - defer s.Close() +func send(env *callEnv, s *jsonrpc.MessageWriter) (err error) { if env.id != nil { s.Field("id", env.id.RawMessage()) } @@ -311,6 +286,7 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { if err != nil { return err } + // if is nil, just write null if env.v == nil { _, err := wr.Write(jsonrpc.Null) if err != nil { @@ -318,25 +294,22 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { } return nil } + // if is not nil, do switch statement switch cast := (env.v).(type) { case json.RawMessage: if len(cast) == 0 { + _, err := wr.Write(jsonrpc.Null) + if err != nil { + return err + } } else { _, err := wr.Write(cast) if err != nil { return err } } - case *io.PipeReader: - _, err := io.Copy(wr, cast) - if err != nil { - return err - } - cast.Close() - case func(e io.Writer) error: - err = cast(wr) default: - err = jjson.Encode(w, cast) + err = jjson.Encode(wr, cast) } return nil } @@ -346,26 +319,40 @@ type notifyEnv struct { dat any } -func (c *callResponder) notify(ctx context.Context, env *notifyEnv) (err error) { - msg := &jsonrpc.Message{} - // allocate a temp buffer for this packet - buf := jjson.GetBuf() - defer jjson.PutBuf(buf) - err = jjson.Encode(buf, env.dat) +func (c *callResponder) notify(env *notifyEnv, s *jsonrpc.MessageWriter) (err error) { + err = s.Field("method", []byte(`"`+env.method+`"`)) if err != nil { - msg.Error = err - } else { - msg.Params = buf.Bytes() + return err } - // add the method - msg.Method = env.method - enc := jx.GetEncoder() - defer jx.PutEncoder(enc) - enc.Grow(4096) - enc.ResetWriter(c.remote) - err = jsonrpc.MarshalMessage(msg, enc) + // if there is no error, we try to marshal the result + wr, err := s.Result() if err != nil { return err } - return enc.Close() + // if is nil, just write null + if env.dat == nil { + _, err := wr.Write(jsonrpc.Null) + if err != nil { + return err + } + return nil + } + // if is not nil, do switch statement + switch cast := (env.dat).(type) { + case json.RawMessage: + if len(cast) == 0 { + _, err := wr.Write(jsonrpc.Null) + if err != nil { + return err + } + } else { + _, err := wr.Write(cast) + if err != nil { + return err + } + } + default: + err = jjson.Encode(wr, cast) + } + return nil } diff --git a/pkg/server/util.go b/pkg/server/util.go index b405962a5cf399d7422d226727d685d46f1596ce..d985ecf372b4ae21fa4766f81729391e79e3ae7a 100644 --- a/pkg/server/util.go +++ b/pkg/server/util.go @@ -8,6 +8,8 @@ import ( type peerInfoContextKey struct{} +type messageStreamContextKeyType struct{} + // PeerInfoFromContext returns information about the client's network connection. // Use this with the context passed to RPC method handler functions. // @@ -19,3 +21,11 @@ func PeerInfoFromContext(ctx context.Context) jsonrpc.PeerInfo { func ContextWithPeerInfo(ctx context.Context, c jsonrpc.PeerInfo) context.Context { return context.WithValue(ctx, peerInfoContextKey{}, c) } + +func MessageStreamFromContext(ctx context.Context) *jsonrpc.MessageStream { + info, _ := ctx.Value(messageStreamContextKeyType{}).(*jsonrpc.MessageStream) + return info +} +func ContextWithMessageStream(ctx context.Context, c *jsonrpc.MessageStream) context.Context { + return context.WithValue(ctx, messageStreamContextKeyType{}, c) +}