diff --git a/contrib/codecs/http/client.go b/contrib/codecs/http/client.go index e18167207355ad50faa21111dafea35770594abd..2f01aa93798a961877848ac33415bef57b8fb4de 100644 --- a/contrib/codecs/http/client.go +++ b/contrib/codecs/http/client.go @@ -88,6 +88,7 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any) } msg := clientutil.GetMessage() defer clientutil.PutMessage(msg) + err = json.NewDecoder(resp.Body).Decode(&msg) if err != nil { return fmt.Errorf("decode json: %w", err) diff --git a/contrib/codecs/http/codec.go b/contrib/codecs/http/codec.go index e1568746f313e430fec00068d14dbd35eea66f0d..caa126e40886f0479ab7149899d0c24d43d02914 100644 --- a/contrib/codecs/http/codec.go +++ b/contrib/codecs/http/codec.go @@ -214,21 +214,22 @@ func (c *Codec) ReadBatch(ctx context.Context) ([]*codec.Message, bool, error) { } // closes the connection -func (c *Codec) Close() error { - c.cn() - return nil +func (c *Codec) Write(p []byte) (n int, err error) { + return c.wr.Write(p) } -func (c *Codec) Send(fn func(e io.Writer) error) error { - if err := fn(c.w); err != nil { +func (c *Codec) Flush() error { + defer c.cn() + err := c.wr.Flush() + if err != nil { return err } return nil } -func (c *Codec) Flush() error { - defer c.cn() - return c.wr.Flush() +func (c *Codec) Close() error { + c.cn() + return nil } // Closed returns a channel which is closed when the connection is closed. diff --git a/contrib/codecs/http/handler.go b/contrib/codecs/http/handler.go index b77009283cee5d7046c9c14d931321f457152642..6c2cd253a33b1c8937d366008cc02d1439530b14 100644 --- a/contrib/codecs/http/handler.go +++ b/contrib/codecs/http/handler.go @@ -26,15 +26,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "no server set", http.StatusInternalServerError) return } - c := codecPool.Get().(*Codec) - c.Reset(w, r) + c := NewCodec(w, r) w.Header().Set("content-type", contentType) err := s.Server.ServeCodec(r.Context(), c) if err != nil { // slog.Error("codec err", "err", err) } - go func() { - <-c.Closed() - codecPool.Put(c) - }() + <-c.Closed() } diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go index efb706c43633fe2200d1c66d10e1c0c817fc5437..bbbc4b3a299d16038c8dabf45d8590b191e9a155 100644 --- a/contrib/codecs/websocket/codec.go +++ b/contrib/codecs/websocket/codec.go @@ -3,6 +3,7 @@ package websocket import ( "context" "io" + "log" "net/http" "sync" "time" @@ -10,10 +11,18 @@ import ( "gfx.cafe/open/websocket" "github.com/goccy/go-json" + _ "net/http/pprof" + "gfx.cafe/open/jrpc/pkg/codec" "gfx.cafe/open/jrpc/pkg/serverutil" ) +func init() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() +} + type Codec struct { closed chan struct{} conn *websocket.Conn @@ -102,7 +111,6 @@ func (c *Codec) Write(p []byte) (n int, err error) { if err != nil { return 0, err } - c.currentFrame = wr } return c.currentFrame.Write(p) @@ -118,7 +126,12 @@ func (c *Codec) Flush() error { } return wr.Close() } - return c.currentFrame.Close() + err := c.currentFrame.Close() + if err != nil { + return err + } + c.currentFrame = nil + return nil } func (c *Codec) PeerInfo() codec.PeerInfo { diff --git a/pkg/server/] b/pkg/server/] deleted file mode 100644 index 381125251def0108e7d73c619aa3f7df139a3dc2..0000000000000000000000000000000000000000 --- a/pkg/server/] +++ /dev/null @@ -1,315 +0,0 @@ -package server - -import ( - "context" - "errors" - "sync" - - "gfx.cafe/open/jrpc/pkg/codec" - "golang.org/x/sync/semaphore" - - "gfx.cafe/util/go/bufpool" - - "github.com/go-faster/jx" - "github.com/goccy/go-json" -) - -// Server is an RPC server. -// it is in charge of calling the handler on the message object, the json encoding of responses, and dealing with batch semantics. -// a server can be used to listenandserve multiple codecs at a time -type Server struct { - services codec.Handler - - lctx context.Context - cn context.CancelFunc -} - -// NewServer creates a new server instance with no registered handlers. -func NewServer(r codec.Handler) *Server { - server := &Server{services: r} - server.lctx, server.cn = context.WithCancel(context.Background()) - return server -} - -// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes -// the response back using the given codec. It will block until the codec is closed -func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) error { - defer remote.Close() - - batchMu := semaphore.NewWeighted(1) - // add a cancel to the context so we can cancel all the child tasks on return - ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo())) - defer cn() - - allErrs := []error{} - var mu sync.Mutex - wg := sync.WaitGroup{} - err := func() error { - for { - // read messages from the stream synchronously - incoming, batch, err := remote.ReadBatch(ctx) - if err != nil { - return err - } - wg.Add(1) - go func() { - defer wg.Done() - - responder := &callResponder{ - remote: remote, - batchMu: batchMu, - batch: batch, - } - err = s.serveBatch(ctx, incoming, responder) - if err != nil { - mu.Lock() - defer mu.Unlock() - allErrs = append(allErrs, err) - } - }() - } - }() - allErrs = append(allErrs, err) - if len(allErrs) > 0 { - return errors.Join(allErrs...) - } - return nil -} - -func (s *Server) Shutdown(ctx context.Context) { - s.cn() -} - -func (s *Server) serveBatch(ctx context.Context, - incoming []*codec.Message, - r *callResponder, -) error { - // check for empty batch - if r.batch && len(incoming) == 0 { - // if it is empty batch, send the empty batch error and immediately return - err := r.send(ctx, &callEnv{ - pkt: &codec.Message{ - ID: codec.NewNullIDPtr(), - Error: codec.NewInvalidRequestError("empty batch"), - }, - }) - if err != nil { - return err - } - } - - rs := []*callRespWriter{} - - totalRequests := 0 - // populate the envelope we are about to send. this is synchronous pre-prpcessing - for _, v := range incoming { - // create the response writer - rw := &callRespWriter{} - rs = append(rs, rw) - // a nil incoming message means an empty response - if v == nil { - rw.msg = &codec.Message{ID: codec.NewNullIDPtr()} - continue - } - rw.msg = v - if v.ID != nil { - totalRequests += 1 - } - } - var doneMu *semaphore.Weighted - doneMu = semaphore.NewWeighted(int64(totalRequests)) - - if totalRequests == 0 { - err := r.remote.Flush() - if err != nil { - return err - } - } - - // create a waitgroup for everything - wg := sync.WaitGroup{} - wg.Add(len(rs)) - // for each item in the envelope - peerInfo := r.remote.PeerInfo() - isBatchWithRequests := totalRequests > 1 && !r.batch - for _, vRef := range rs { - v := vRef - if isBatchWithRequests { - v.noStream = isBatchWithRequests - v.doneMu = doneMu - } - // now process each request in its own goroutine - // TODO: stress test this. - go func() { - defer wg.Done() - // early respond to nil requests - if v.msg == nil || len(v.msg.Method) == 0 { - v.msg.Error = codec.NewInvalidRequestError("invalid request") - return - } - req := codec.NewRequestFromMessage( - ctx, - v.msg, - ) - req.Peer = peerInfo - s.services.ServeRPC(v, req) - }() - } - // we only need to do this if this is a batch call with requests - if isBatchWithRequests { - // first we need to wait for every single request to be completed - err := doneMu.Acquire(ctx, int64(totalRequests)) - if err != nil { - return err - } - // now write the prefix - _, err = r.remote.Write([]byte{'['}) - if err != nil { - return err - } - // release them, one by one - for i := 0; i < totalRequests; i++ { - // release one - canCh <- struct{}{} - // wait for finish - <-doneCh - // write the comma or ] - char := ',' - if i == totalRequests-1 { - char = ']' - } - _, err = r.remote.Write([]byte{byte(char)}) - if err != nil { - return err - } - } - } - wg.Wait() - return nil -} - -type callResponder struct { - remote codec.ReaderWriter - mu *semaphore.Weighted - batchMu *semaphore.Weighted - - batch bool - batchStarted bool -} - -type callEnv struct { - v any - err error - pkt *codec.Message - id *codec.ID - extrafields codec.ExtraFields -} - -func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { - err = c.mu.Acquire(ctx, 1) - if err != nil { - return err - } - defer c.mu.Release(1) - // notification gets nothing - // if all msgs in batch are notification, we trigger an allSkip and write nothing - //if c.batch { - // allSkip := true - // for _, v := range env.responses { - // if v.skip != true { - // allSkip = false - // } - // } - // if allSkip { - // return c.remote.Send(func(e *jx.Encoder) error { return nil }) - // } - //} - // create the streaming encoder - enc := jx.GetEncoder() - enc.ResetWriter(c.remote) - enc.Obj(func(e *jx.Encoder) { - e.Field("jsonrpc", func(e *jx.Encoder) { - e.Str("2.0") - }) - if env.id != nil { - e.Field("id", func(e *jx.Encoder) { - e.Raw(env.id.RawMessage()) - }) - } - if env.extrafields != nil { - for k, v := range env.extrafields { - e.Field(k, func(e *jx.Encoder) { - e.Raw(v) - }) - } - } - if env.err != nil { - e.Field("error", func(e *jx.Encoder) { - codec.EncodeError(e, env.err) - }) - } else { - // if there is no error, we try to marshal the result - e.Field("result", func(e *jx.Encoder) { - if env.v != nil { - switch cast := env.v.(type) { - case json.RawMessage: - e.Raw(cast) - default: - err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) { - eo.DisableNewline = true - }) - if err != nil { - return - } - } - } else { - e.Null() - } - }) - } - }) - // a json encoding error here is possibly fatal.... - if err != nil { - return err - } - return enc.Close() -} - -type notifyEnv struct { - method string - dat any - extra codec.ExtraFields -} - -func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error { - err := c.mu.Acquire(ctx, 1) - if err != nil { - return err - } - defer c.mu.Release(1) - err = c.batchMu.Acquire(ctx, 1) - if err != nil { - return err - } - defer c.batchMu.Release(1) - msg := &codec.Message{} - // allocate a temp buffer for this packet - buf := bufpool.GetStd() - defer bufpool.PutStd(buf) - err = json.NewEncoder(buf).Encode(env.dat) - if err != nil { - msg.Error = err - } else { - msg.Params = buf.Bytes() - } - msg.ExtraFields = env.extra - // add the method - msg.Method = env.method - enc := jx.GetEncoder() - enc.ResetWriter(c.remote) - err = codec.MarshalMessage(msg, enc) - if err != nil { - return err - } - return enc.Close() -} diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go index 321e9a395cd6016a0607a17c05a8cf0aea94d918..4a31120d8b2d46f8bc0d3dcb88d651cfb2388cf2 100644 --- a/pkg/server/responsewriter.go +++ b/pkg/server/responsewriter.go @@ -2,8 +2,10 @@ package server import ( "context" + "log" "net/http" "sync" + "time" "gfx.cafe/open/jrpc/pkg/codec" "github.com/goccy/go-json" @@ -41,15 +43,16 @@ func (c *callRespWriter) Send(v any, e error) (err error) { } c.sendCalled = true // defer the sending of this for later - defer c.doneMu.Release(1) + if c.doneMu != nil { + defer c.doneMu.Release(1) + } // batch requests are not individually streamed. // the reason is beacuse i couldn't think of a good way to implement it // ultimately they need to be buffered. there's some optimistic multiplexing you can // do, but that felt really complicated and not worth the time. if c.noStream { - if e == nil { + if e != nil { c.err = e - return nil } if v != nil { // json marshaling errors are reported to the handler @@ -57,20 +60,25 @@ func (c *callRespWriter) Send(v any, e error) (err error) { if err != nil { return err } - return nil } + return nil } + s := time.Now() + log.Println("try") err = c.cr.mu.Acquire(c.ctx, 1) if err != nil { return err } + log.Println("release", time.Since(s)) + s2 := time.Now() defer c.cr.mu.Release(1) err = c.cr.send(c.ctx, &callEnv{ - v: v, + v: &v, err: e, id: c.msg.ID, extrafields: c.msg.ExtraFields, }) + log.Println("release", time.Since(s2)) err = c.cr.remote.Flush() if err != nil { return err @@ -90,7 +98,12 @@ func (c *callRespWriter) Header() http.Header { } func (c *callRespWriter) Notify(method string, v any) error { - err := c.cr.notify(c.ctx, ¬ifyEnv{ + err := c.cr.mu.Acquire(c.ctx, 1) + if err != nil { + return err + } + defer c.cr.mu.Release(1) + err = c.cr.notify(c.ctx, ¬ifyEnv{ method: method, dat: v, extra: c.msg.ExtraFields, diff --git a/pkg/server/server.go b/pkg/server/server.go index 3ae29cbf1c70d6db2a6bed9e0c360491ad45c7f6..8ca1430a2a11d6a18f8a5c0dad6029a6047dea3a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -61,7 +61,6 @@ func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) erro } err = s.serveBatch(ctx, incoming, responder) if err != nil { - // remote.Flush() mu.Lock() defer mu.Unlock() allErrs = append(allErrs, err) @@ -69,6 +68,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) erro }() } }() + wg.Wait() allErrs = append(allErrs, err) if len(allErrs) > 0 { return errors.Join(allErrs...) @@ -111,10 +111,12 @@ func (s *Server) serveBatch(ctx context.Context, rs = append(rs, rw) // a nil incoming message means an empty response if v == nil { - rw.msg = &codec.Message{ID: codec.NewNullIDPtr()} - continue + v = &codec.Message{ID: codec.NewNullIDPtr()} } rw.msg = v + if len(v.Method) == 0 { + rw.err = codec.NewInvalidRequestError("invalid request") + } if v.ID != nil { totalRequests += 1 } @@ -131,24 +133,27 @@ func (s *Server) serveBatch(ctx context.Context, wg.Add(len(rs)) // for each item in the envelope peerInfo := r.remote.PeerInfo() - isBatchWithRequests := totalRequests > 1 && !r.batch batchResults := []*callRespWriter{} for _, vRef := range rs { v := vRef - v.doneMu = doneMu - if isBatchWithRequests { + if r.batch { v.noStream = true - batchResults = append(batchResults, v) + if v.msg.ID != nil { + v.doneMu = doneMu + batchResults = append(batchResults, v) + } + } + // early respond to nil requests + if v.err != nil { + v.sendCalled = true + v.doneMu.Release(1) + wg.Done() + continue } // now process each request in its own goroutine // TODO: stress test this. go func() { defer wg.Done() - // early respond to nil requests - if v.msg == nil || len(v.msg.Method) == 0 { - v.msg.Error = codec.NewInvalidRequestError("invalid request") - return - } req := codec.NewRequestFromMessage( ctx, v.msg, @@ -157,13 +162,13 @@ func (s *Server) serveBatch(ctx context.Context, s.services.ServeRPC(v, req) }() } - // we only need to do this if this is a batch call with requests - // first we need to wait for every single request to be completed - err = doneMu.Acquire(ctx, int64(totalRequests)) - if err != nil { - return err - } - if isBatchWithRequests { + + if r.batch { + // we only need to do this if this is a batch call with requests + err = doneMu.Acquire(ctx, int64(totalRequests)) + if err != nil { + return err + } err = r.mu.Acquire(ctx, 1) if err != nil { return err @@ -175,8 +180,10 @@ func (s *Server) serveBatch(ctx context.Context, return err } for i, v := range batchResults { + var a any + a = v.payload err = r.send(ctx, &callEnv{ - v: v.payload, + v: &a, err: v.err, id: v.msg.ID, extrafields: v.msg.ExtraFields, @@ -198,6 +205,16 @@ func (s *Server) serveBatch(ctx context.Context, if err != nil { return err } + } else if totalRequests == 0 { + err = r.mu.Acquire(ctx, 1) + if err != nil { + return err + } + defer r.mu.Release(1) + err := r.remote.Flush() + if err != nil { + return err + } } wg.Wait() return nil @@ -212,7 +229,7 @@ type callResponder struct { } type callEnv struct { - v any + v *any err error pkt *codec.Message id *codec.ID @@ -248,7 +265,7 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { // if there is no error, we try to marshal the result e.Field("result", func(e *jx.Encoder) { if env.v != nil { - switch cast := env.v.(type) { + switch cast := (*env.v).(type) { case json.RawMessage: e.Raw(cast) default: