diff --git a/pkg/codec/errors.go b/pkg/codec/errors.go index b9d1aad201964a61e6c870c27da1ee4938b45a17..ec8c092f6467968e3ded5f0f5f2d629e342272c6 100644 --- a/pkg/codec/errors.go +++ b/pkg/codec/errors.go @@ -124,13 +124,13 @@ type ErrorParse struct{ message string } func (e *ErrorParse) ErrorCode() int { return -32700 } func (e *ErrorParse) Error() string { return e.message } +// received message isn't a valid request +type ErrorInvalidRequest struct{ message string } + func NewInvalidRequestError(message string) *ErrorInvalidRequest { return &ErrorInvalidRequest{message: message} } -// received message isn't a valid request -type ErrorInvalidRequest struct{ message string } - func (e *ErrorInvalidRequest) ErrorCode() int { return -32600 } func (e *ErrorInvalidRequest) Error() string { return e.message } @@ -140,15 +140,23 @@ type ErrorInvalidMessage struct{ message string } func (e *ErrorInvalidMessage) ErrorCode() int { return -32700 } func (e *ErrorInvalidMessage) Error() string { return e.message } +// unable to decode supplied params, or an invalid number of parameters +type ErrorInvalidParams struct{ message string } + func NewInvalidParamsError(message string) *ErrorInvalidParams { return &ErrorInvalidParams{message: message} } +func (e *ErrorInvalidParams) ErrorCode() int { return -32602 } +func (e *ErrorInvalidParams) Error() string { return e.message } // unable to decode supplied params, or an invalid number of parameters -type ErrorInvalidParams struct{ message string } +type ErrorInternalError struct{ message string } -func (e *ErrorInvalidParams) ErrorCode() int { return -32602 } -func (e *ErrorInvalidParams) Error() string { return e.message } +func NewInternalError(message string) *ErrorInternalError { + return &ErrorInternalError{message: message} +} +func (e *ErrorInternalError) ErrorCode() int { return -32603 } +func (e *ErrorInternalError) Error() string { return e.message } // HTTPError is returned by client operations when the HTTP status code of the // response is not a 2xx status. diff --git a/pkg/codec/reqresp.go b/pkg/codec/reqresp.go index c00371aaa0414e184c68e9eb4f4e01036577d95a..3ceb677222eaa6929e73e7591141d2fd341a9e1e 100644 --- a/pkg/codec/reqresp.go +++ b/pkg/codec/reqresp.go @@ -2,7 +2,6 @@ package codec import ( "context" - "net/http" json "github.com/goccy/go-json" ) @@ -10,11 +9,8 @@ import ( // http.ResponseWriter interface, but for jrpc type ResponseWriter interface { Send(v any, err error) error - Header() http.Header - - SetExtraField(k string, v any) error - Notify(method string, v any) error + ExtraFields() ExtraFields } // BatchElem is an element in a batch request. diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go index ccb01c820b574a62c4dd1884c505e2d80fa2517a..967efbf19bc7cd0ea5d243f244d5c358390b4cac 100644 --- a/pkg/jrpctest/suites.go +++ b/pkg/jrpctest/suites.go @@ -128,10 +128,10 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { for i := range batch { a := batch[i] b := wantResult[i] - assert.EqualValuesf(t, a.Method, b.Method, "item %d", i) - assert.EqualValuesf(t, a.Result, b.Result, "item %d", i) - assert.EqualValuesf(t, a.Params, b.Params, "item %d", i) - assert.EqualValuesf(t, a.Error, b.Error, "item %d", i) + assert.EqualValuesf(t, b.Method, a.Method, "item %d", i) + assert.EqualValuesf(t, b.Result, a.Result, "item %d", i) + assert.EqualValuesf(t, b.Params, a.Params, "item %d", i) + assert.EqualValuesf(t, b.Error, a.Error, "item %d", i) } }) diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go index d39f7f3b51e14cc5c0f244345a75394023a21441..5367ae454d04c29ee2b6e26f417b06de2d14b25a 100644 --- a/pkg/server/responsewriter.go +++ b/pkg/server/responsewriter.go @@ -1,15 +1,10 @@ package server import ( - "bytes" "context" - "net/http" "sync" "gfx.cafe/open/jrpc/pkg/codec" - "gfx.cafe/util/go/bufpool" - "github.com/goccy/go-json" - "golang.org/x/sync/semaphore" ) // 16mb... should be more than enough for any batch. @@ -17,27 +12,22 @@ import ( // TODO: make this configurable const maxBatchSizeBytes = 1024 * 1024 * 1024 * 16 -var _ codec.ResponseWriter = (*callRespWriter)(nil) +var _ codec.ResponseWriter = (*streamingRespWriter)(nil) -// callRespWriter is NOT thread safe -type callRespWriter struct { +// streamingRespWriter is NOT thread safe +type streamingRespWriter struct { cr *callResponder msg *codec.Message ctx context.Context - noStream bool - doneMu *semaphore.Weighted - - payload json.RawMessage - err error + err error sendCalled bool - header http.Header mu sync.Mutex } -func (c *callRespWriter) Send(v any, e error) (err error) { +func (c *streamingRespWriter) Send(v any, e error) (err error) { c.mu.Lock() defer c.mu.Unlock() if c.msg.ID == nil { @@ -47,75 +37,41 @@ func (c *callRespWriter) Send(v any, e error) (err error) { return codec.ErrSendAlreadyCalled } c.sendCalled = true - // defer the sending of this for later - if c.doneMu != nil { - defer c.doneMu.Release(1) + ce := &callEnv{ + err: c.err, + id: c.msg.ID, + extrafields: c.msg.ExtraFields, + } + // only override error if not already set + if ce.err == nil { + ce.err = e } - // batch requests are not individually streamed. - // the reason is beacuse i couldn't think of a good way to implement it - // ultimately they need to be buffered. there's some optimistic multiplexing you can - // do, but that felt really complicated and not worth the time. - if c.noStream { - if c.err == nil { - c.err = e - } - if v != nil { - // json marshaling errors are reported to the handler - buf := bufpool.GlobalPool.GetStd() - w := newWriter(buf, maxBatchSizeBytes, false) - err = json.NewEncoder(w).Encode(v) - if err != nil { - return err - } - c.payload = json.RawMessage(bytes.TrimSuffix(buf.Bytes(), []byte{'\n'})) - return nil - } - return nil + // only set value if value is not nil + if v != nil { + ce.v = &v } err = c.cr.mu.Acquire(c.ctx, 1) if err != nil { return err } - select { - case <-c.ctx.Done(): - return c.ctx.Err() - default: - } defer c.cr.mu.Release(1) if c.err != nil { e = c.err } - ce := &callEnv{ - err: e, - id: c.msg.ID, - extrafields: c.msg.ExtraFields, - } - if v != nil { - ce.v = &v - } - err = c.cr.send(c.ctx, ce) - if err != nil { + if err = c.cr.send(c.ctx, ce); err != nil { return err } - err = c.cr.remote.Flush() - if err != nil { + if err = c.cr.remote.Flush(); err != nil { return err } return nil } -func (c *callRespWriter) SetExtraField(k string, v any) error { - c.mu.Lock() - defer c.mu.Unlock() - c.msg.SetExtraField(k, v) - return nil -} - -func (c *callRespWriter) Header() http.Header { - return c.header +func (c *streamingRespWriter) ExtraFields() codec.ExtraFields { + return c.msg.ExtraFields } -func (c *callRespWriter) Notify(method string, v any) error { +func (c *streamingRespWriter) Notify(method string, v any) error { err := c.cr.mu.Acquire(c.ctx, 1) if err != nil { return err diff --git a/pkg/server/rw_batch.go b/pkg/server/rw_batch.go new file mode 100644 index 0000000000000000000000000000000000000000..ce7fbe01e96fa6c0de7d8e4c3157e94cb996ece1 --- /dev/null +++ b/pkg/server/rw_batch.go @@ -0,0 +1,87 @@ +package server + +import ( + "bytes" + "context" + "sync" + + "gfx.cafe/open/jrpc/pkg/codec" + "github.com/goccy/go-json" +) + +// batchingRespWriter is NOT thread safe +type batchingRespWriter struct { + cr *callResponder + msg *codec.Message + ctx context.Context + + wg *sync.WaitGroup + payload json.RawMessage + err error + + sendCalled bool + + mu sync.Mutex +} + +func (c *batchingRespWriter) Send(v any, e error) (err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.msg.ID == nil { + return codec.ErrCantSendNotification + } + if c.sendCalled { + return codec.ErrSendAlreadyCalled + } + c.sendCalled = true + if c.wg != nil { + defer c.wg.Done() + } + // if there is an error, and no c.err is set, and there is an e, then set c.err to e + if c.err == nil { + c.err = e + } + // batch requests are not individually streamed. + // the reason is beacuse i couldn't think of a good way to implement it + // ultimately they need to be buffered. there's some optimistic multiplexing you can + // do, but that felt really complicated and not worth the time. + if v != nil && c.err == nil { + buf := &bytes.Buffer{} + w := newWriter(buf, maxBatchSizeBytes, false) + err = json.NewEncoder(w).Encode(v) + if err != nil { + // the user just gets a generic error saying that the json is bad + c.err = codec.NewInternalError("server sent bad json") + // json marshaling errors are reported to the Send call, not the user + return err + } + c.payload = json.RawMessage(bytes.TrimSuffix(buf.Bytes(), []byte{'\n'})) + return nil + } + return nil +} + +func (c *batchingRespWriter) ExtraFields() codec.ExtraFields { + return c.msg.ExtraFields +} + +func (c *batchingRespWriter) Notify(method string, v any) error { + err := c.cr.mu.Acquire(c.ctx, 1) + if err != nil { + return err + } + defer c.cr.mu.Release(1) + err = c.cr.notify(c.ctx, ¬ifyEnv{ + method: method, + dat: v, + extra: c.msg.ExtraFields, + }) + if err != nil { + return err + } + err = c.cr.remote.Flush() + if err != nil { + return err + } + return nil +} diff --git a/pkg/server/server.go b/pkg/server/server.go index e49209d5b5df5adf13e3032e6c204b316c022835..4e88eaa53f0d54af7751393d72f344f38f513eb2 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -59,7 +59,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) erro batch: batch, mu: sema, } - err = s.serveBatch(ctx, incoming, responder) + err = s.serve(ctx, incoming, responder) if err != nil { mu.Lock() defer mu.Unlock() @@ -80,6 +80,72 @@ func (s *Server) Shutdown(ctx context.Context) { s.cn() } +func (s *Server) serve(ctx context.Context, + incoming []*codec.Message, + r *callResponder, +) error { + if r.batch { + return s.serveBatch(ctx, incoming, r) + } else { + return s.serveSingle(ctx, incoming[0], r) + } +} + +func (s *Server) serveSingle(ctx context.Context, + incoming *codec.Message, + r *callResponder, +) error { + rw := &streamingRespWriter{ + ctx: ctx, + cr: r, + } + rw.msg, rw.err = produceOutputMessage(incoming) + req := codec.NewRequestFromMessage( + ctx, + rw.msg, + ) + req.Peer = r.remote.PeerInfo() + if rw.msg.ID == nil { + // all notification, so immediately flush + err := r.mu.Acquire(ctx, 1) + if err != nil { + return err + } + defer r.mu.Release(1) + err = r.remote.Flush() + if err != nil { + return err + } + } + s.services.ServeRPC(rw, req) + if rw.sendCalled == false && rw.msg.ID != nil { + rw.Send(codec.Null, nil) + } + return nil +} + +func produceOutputMessage(inputMessage *codec.Message) (out *codec.Message, err error) { + // a nil incoming message means return an invalid request. + if inputMessage == nil { + inputMessage = &codec.Message{ID: codec.NewNullIDPtr()} + err = codec.NewInvalidRequestError("invalid request") + } + out = inputMessage + out.ExtraFields = codec.ExtraFields{} + out.Error = nil + // zero length method is always invalid request + if len(out.Method) == 0 { + // assume if the method is not there AND the id is not there that it's an invalid REQUEST not notification + // this makes sure we add 1 to totalRequests + if out.ID == nil { + out.ID = codec.NewNullIDPtr() + } + err = codec.NewInvalidRequestError("invalid request") + } + + return +} + func (s *Server) serveBatch(ctx context.Context, incoming []*codec.Message, r *callResponder, @@ -106,80 +172,60 @@ func (s *Server) serveBatch(ctx context.Context, return nil } - rs := []*callRespWriter{} + rs := []*batchingRespWriter{} totalRequests := 0 // populate the envelope we are about to send. this is synchronous pre-prpcessing for _, v := range incoming { // create the response writer - rw := &callRespWriter{ + rw := &batchingRespWriter{ ctx: ctx, cr: r, } rs = append(rs, rw) - // a nil incoming message means return an invalid request. - if v == nil { - v = &codec.Message{ID: codec.NewNullIDPtr()} - rw.err = codec.NewInvalidRequestError("invalid request") - } - rw.msg = v - rw.msg.ExtraFields = codec.ExtraFields{} - rw.msg.Error = nil - // zero length method is always invalid request - if len(v.Method) == 0 { - // assume if the method is not there AND the id is not there that it's an invalid REQUEST not notification - // this makes sure we add 1 to totalRequests - if v.ID == nil { - v.ID = codec.NewNullIDPtr() - } - rw.err = codec.NewInvalidRequestError("invalid request") - } + rw.msg, rw.err = produceOutputMessage(v) // requests and malformed requests both count as requests - if v.ID != nil { + if rw.msg.ID != nil { totalRequests += 1 } } - var doneMu *semaphore.Weighted - doneMu = semaphore.NewWeighted(int64(totalRequests)) - err := doneMu.Acquire(ctx, int64(totalRequests)) - if err != nil { - return err - } - // create a waitgroup for everything - wg := sync.WaitGroup{} - wg.Add(len(rs)) + // create a waitgroup for when every handler returns + returnWg := sync.WaitGroup{} + returnWg.Add(len(rs)) // for each item in the envelope peerInfo := r.remote.PeerInfo() - batchResults := []*callRespWriter{} + batchResults := []*batchingRespWriter{} + + respWg := &sync.WaitGroup{} + respWg.Add(totalRequests) + for _, vRef := range rs { v := vRef - if r.batch { - v.noStream = true - if v.msg.ID != nil { - v.doneMu = doneMu - batchResults = append(batchResults, v) - } + if v.msg.ID != nil { + v.wg = respWg + batchResults = append(batchResults, v) } // now process each request in its own goroutine // TODO: stress test this. go func() { - defer wg.Done() + defer returnWg.Done() req := codec.NewRequestFromMessage( ctx, v.msg, ) req.Peer = peerInfo s.services.ServeRPC(v, req) + if v.sendCalled == false && v.err == nil { + v.Send(codec.Null, nil) + } }() } - if r.batch && totalRequests > 0 { - err = doneMu.Acquire(ctx, int64(totalRequests)) - if err != nil { - return err - } - err = r.mu.Acquire(ctx, 1) + if totalRequests > 0 { + // TODO: channel? + respWg.Wait() + err := r.mu.Acquire(ctx, 1) if err != nil { return err } @@ -190,10 +236,8 @@ func (s *Server) serveBatch(ctx context.Context, return err } for i, v := range batchResults { - var a any - a = v.payload err = r.send(ctx, &callEnv{ - v: &a, + v: v.payload, err: v.err, id: v.msg.ID, extrafields: v.msg.ExtraFields, @@ -216,17 +260,18 @@ func (s *Server) serveBatch(ctx context.Context, return err } } else if totalRequests == 0 { - err = r.mu.Acquire(ctx, 1) + // all notification, so immediately flush + err := r.mu.Acquire(ctx, 1) if err != nil { return err } defer r.mu.Release(1) - err := r.remote.Flush() + err = r.remote.Flush() if err != nil { return err } } - wg.Wait() + returnWg.Wait() return nil } @@ -239,7 +284,7 @@ type callResponder struct { } type callEnv struct { - v *any + v any err error id *codec.ID extrafields codec.ExtraFields @@ -274,9 +319,13 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { // if there is no error, we try to marshal the result e.Field("result", func(e *jx.Encoder) { if env.v != nil { - switch cast := (*env.v).(type) { + switch cast := (env.v).(type) { case json.RawMessage: - e.Raw(cast) + if len(cast) == 0 { + e.Null() + } else { + e.Raw(cast) + } default: err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) { eo.DisableNewline = true