From 413679a7b5166df7e178be4629bb58d1224570e3 Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Sat, 28 Oct 2023 00:15:29 -0500
Subject: [PATCH] server: buffer and order batch responses, simplify codec write path
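
Buffer batch responses per request and write them out in order as a single
JSON array instead of streaming them individually. Replace the http codec's
Send callback with plain Write/Flush methods, construct a codec per request
in the http handler instead of pooling, and reset the websocket codec's
current frame after a flush. Wait for in-flight batches before ServeCodec
returns, and delete the stray "pkg/server/]" file.

Note: the patch still carries a temporary pprof listener and debug logging
in the websocket codec and response writer.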
---
 contrib/codecs/http/client.go     |   1 +
 contrib/codecs/http/codec.go      |  17 +-
 contrib/codecs/http/handler.go    |   8 +-
 contrib/codecs/websocket/codec.go |  17 +-
 pkg/server/]                      | 315 ------------------------------
 pkg/server/responsewriter.go      |  25 ++-
 pkg/server/server.go              |  61 +++---
 7 files changed, 85 insertions(+), 359 deletions(-)
delete mode 100644 pkg/server/]
diff --git a/contrib/codecs/http/client.go b/contrib/codecs/http/client.go
index e181672..2f01aa9 100644
--- a/contrib/codecs/http/client.go
+++ b/contrib/codecs/http/client.go
@@ -88,6 +88,7 @@ func (c *Client) Do(ctx context.Context, result any, method string, params any)
}
msg := clientutil.GetMessage()
defer clientutil.PutMessage(msg)
+
err = json.NewDecoder(resp.Body).Decode(&msg)
if err != nil {
return fmt.Errorf("decode json: %w", err)
diff --git a/contrib/codecs/http/codec.go b/contrib/codecs/http/codec.go
index e156874..caa126e 100644
--- a/contrib/codecs/http/codec.go
+++ b/contrib/codecs/http/codec.go
@@ -214,21 +214,22 @@ func (c *Codec) ReadBatch(ctx context.Context) ([]*codec.Message, bool, error) {
}
// closes the connection
-func (c *Codec) Close() error {
- c.cn()
- return nil
+func (c *Codec) Write(p []byte) (n int, err error) {
+ return c.wr.Write(p)
}
-func (c *Codec) Send(fn func(e io.Writer) error) error {
- if err := fn(c.w); err != nil {
+func (c *Codec) Flush() error {
+ defer c.cn()
+ err := c.wr.Flush()
+ if err != nil {
return err
}
return nil
}
-func (c *Codec) Flush() error {
- defer c.cn()
- return c.wr.Flush()
+func (c *Codec) Close() error {
+ c.cn()
+ return nil
}
// Closed returns a channel which is closed when the connection is closed.
diff --git a/contrib/codecs/http/handler.go b/contrib/codecs/http/handler.go
index b770092..6c2cd25 100644
--- a/contrib/codecs/http/handler.go
+++ b/contrib/codecs/http/handler.go
@@ -26,15 +26,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "no server set", http.StatusInternalServerError)
return
}
- c := codecPool.Get().(*Codec)
- c.Reset(w, r)
+ c := NewCodec(w, r)
w.Header().Set("content-type", contentType)
err := s.Server.ServeCodec(r.Context(), c)
if err != nil {
// slog.Error("codec err", "err", err)
}
- go func() {
- <-c.Closed()
- codecPool.Put(c)
- }()
+ <-c.Closed()
}
diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go
index efb706c..bbbc4b3 100644
--- a/contrib/codecs/websocket/codec.go
+++ b/contrib/codecs/websocket/codec.go
@@ -3,6 +3,7 @@ package websocket
import (
"context"
"io"
+ "log"
"net/http"
"sync"
"time"
@@ -10,10 +11,18 @@ import (
"gfx.cafe/open/websocket"
"github.com/goccy/go-json"
+ _ "net/http/pprof"
+
"gfx.cafe/open/jrpc/pkg/codec"
"gfx.cafe/open/jrpc/pkg/serverutil"
)
+func init() {
+ go func() {
+ log.Println(http.ListenAndServe("localhost:6060", nil))
+ }()
+}
+
type Codec struct {
closed chan struct{}
conn *websocket.Conn
@@ -102,7 +111,6 @@ func (c *Codec) Write(p []byte) (n int, err error) {
if err != nil {
return 0, err
}
-
c.currentFrame = wr
}
return c.currentFrame.Write(p)
@@ -118,7 +126,12 @@ func (c *Codec) Flush() error {
}
return wr.Close()
}
- return c.currentFrame.Close()
+ err := c.currentFrame.Close()
+ if err != nil {
+ return err
+ }
+ c.currentFrame = nil
+ return nil
}
func (c *Codec) PeerInfo() codec.PeerInfo {
diff --git a/pkg/server/] b/pkg/server/]
deleted file mode 100644
index 3811252..0000000
--- a/pkg/server/]
+++ /dev/null
@@ -1,315 +0,0 @@
-package server
-
-import (
- "context"
- "errors"
- "sync"
-
- "gfx.cafe/open/jrpc/pkg/codec"
- "golang.org/x/sync/semaphore"
-
- "gfx.cafe/util/go/bufpool"
-
- "github.com/go-faster/jx"
- "github.com/goccy/go-json"
-)
-
-// Server is an RPC server.
-// it is in charge of calling the handler on the message object, the json encoding of responses, and dealing with batch semantics.
-// a server can be used to listenandserve multiple codecs at a time
-type Server struct {
- services codec.Handler
-
- lctx context.Context
- cn context.CancelFunc
-}
-
-// NewServer creates a new server instance with no registered handlers.
-func NewServer(r codec.Handler) *Server {
- server := &Server{services: r}
- server.lctx, server.cn = context.WithCancel(context.Background())
- return server
-}
-
-// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes
-// the response back using the given codec. It will block until the codec is closed
-func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) error {
- defer remote.Close()
-
- batchMu := semaphore.NewWeighted(1)
- // add a cancel to the context so we can cancel all the child tasks on return
- ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
- defer cn()
-
- allErrs := []error{}
- var mu sync.Mutex
- wg := sync.WaitGroup{}
- err := func() error {
- for {
- // read messages from the stream synchronously
- incoming, batch, err := remote.ReadBatch(ctx)
- if err != nil {
- return err
- }
- wg.Add(1)
- go func() {
- defer wg.Done()
-
- responder := &callResponder{
- remote: remote,
- batchMu: batchMu,
- batch: batch,
- }
- err = s.serveBatch(ctx, incoming, responder)
- if err != nil {
- mu.Lock()
- defer mu.Unlock()
- allErrs = append(allErrs, err)
- }
- }()
- }
- }()
- allErrs = append(allErrs, err)
- if len(allErrs) > 0 {
- return errors.Join(allErrs...)
- }
- return nil
-}
-
-func (s *Server) Shutdown(ctx context.Context) {
- s.cn()
-}
-
-func (s *Server) serveBatch(ctx context.Context,
- incoming []*codec.Message,
- r *callResponder,
-) error {
- // check for empty batch
- if r.batch && len(incoming) == 0 {
- // if it is empty batch, send the empty batch error and immediately return
- err := r.send(ctx, &callEnv{
- pkt: &codec.Message{
- ID: codec.NewNullIDPtr(),
- Error: codec.NewInvalidRequestError("empty batch"),
- },
- })
- if err != nil {
- return err
- }
- }
-
- rs := []*callRespWriter{}
-
- totalRequests := 0
- // populate the envelope we are about to send. this is synchronous pre-prpcessing
- for _, v := range incoming {
- // create the response writer
- rw := &callRespWriter{}
- rs = append(rs, rw)
- // a nil incoming message means an empty response
- if v == nil {
- rw.msg = &codec.Message{ID: codec.NewNullIDPtr()}
- continue
- }
- rw.msg = v
- if v.ID != nil {
- totalRequests += 1
- }
- }
- var doneMu *semaphore.Weighted
- doneMu = semaphore.NewWeighted(int64(totalRequests))
-
- if totalRequests == 0 {
- err := r.remote.Flush()
- if err != nil {
- return err
- }
- }
-
- // create a waitgroup for everything
- wg := sync.WaitGroup{}
- wg.Add(len(rs))
- // for each item in the envelope
- peerInfo := r.remote.PeerInfo()
- isBatchWithRequests := totalRequests > 1 && !r.batch
- for _, vRef := range rs {
- v := vRef
- if isBatchWithRequests {
- v.noStream = isBatchWithRequests
- v.doneMu = doneMu
- }
- // now process each request in its own goroutine
- // TODO: stress test this.
- go func() {
- defer wg.Done()
- // early respond to nil requests
- if v.msg == nil || len(v.msg.Method) == 0 {
- v.msg.Error = codec.NewInvalidRequestError("invalid request")
- return
- }
- req := codec.NewRequestFromMessage(
- ctx,
- v.msg,
- )
- req.Peer = peerInfo
- s.services.ServeRPC(v, req)
- }()
- }
- // we only need to do this if this is a batch call with requests
- if isBatchWithRequests {
- // first we need to wait for every single request to be completed
- err := doneMu.Acquire(ctx, int64(totalRequests))
- if err != nil {
- return err
- }
- // now write the prefix
- _, err = r.remote.Write([]byte{'['})
- if err != nil {
- return err
- }
- // release them, one by one
- for i := 0; i < totalRequests; i++ {
- // release one
- canCh <- struct{}{}
- // wait for finish
- <-doneCh
- // write the comma or ]
- char := ','
- if i == totalRequests-1 {
- char = ']'
- }
- _, err = r.remote.Write([]byte{byte(char)})
- if err != nil {
- return err
- }
- }
- }
- wg.Wait()
- return nil
-}
-
-type callResponder struct {
- remote codec.ReaderWriter
- mu *semaphore.Weighted
- batchMu *semaphore.Weighted
-
- batch bool
- batchStarted bool
-}
-
-type callEnv struct {
- v any
- err error
- pkt *codec.Message
- id *codec.ID
- extrafields codec.ExtraFields
-}
-
-func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
- err = c.mu.Acquire(ctx, 1)
- if err != nil {
- return err
- }
- defer c.mu.Release(1)
- // notification gets nothing
- // if all msgs in batch are notification, we trigger an allSkip and write nothing
- //if c.batch {
- // allSkip := true
- // for _, v := range env.responses {
- // if v.skip != true {
- // allSkip = false
- // }
- // }
- // if allSkip {
- // return c.remote.Send(func(e *jx.Encoder) error { return nil })
- // }
- //}
- // create the streaming encoder
- enc := jx.GetEncoder()
- enc.ResetWriter(c.remote)
- enc.Obj(func(e *jx.Encoder) {
- e.Field("jsonrpc", func(e *jx.Encoder) {
- e.Str("2.0")
- })
- if env.id != nil {
- e.Field("id", func(e *jx.Encoder) {
- e.Raw(env.id.RawMessage())
- })
- }
- if env.extrafields != nil {
- for k, v := range env.extrafields {
- e.Field(k, func(e *jx.Encoder) {
- e.Raw(v)
- })
- }
- }
- if env.err != nil {
- e.Field("error", func(e *jx.Encoder) {
- codec.EncodeError(e, env.err)
- })
- } else {
- // if there is no error, we try to marshal the result
- e.Field("result", func(e *jx.Encoder) {
- if env.v != nil {
- switch cast := env.v.(type) {
- case json.RawMessage:
- e.Raw(cast)
- default:
- err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) {
- eo.DisableNewline = true
- })
- if err != nil {
- return
- }
- }
- } else {
- e.Null()
- }
- })
- }
- })
- // a json encoding error here is possibly fatal....
- if err != nil {
- return err
- }
- return enc.Close()
-}
-
-type notifyEnv struct {
- method string
- dat any
- extra codec.ExtraFields
-}
-
-func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error {
- err := c.mu.Acquire(ctx, 1)
- if err != nil {
- return err
- }
- defer c.mu.Release(1)
- err = c.batchMu.Acquire(ctx, 1)
- if err != nil {
- return err
- }
- defer c.batchMu.Release(1)
- msg := &codec.Message{}
- // allocate a temp buffer for this packet
- buf := bufpool.GetStd()
- defer bufpool.PutStd(buf)
- err = json.NewEncoder(buf).Encode(env.dat)
- if err != nil {
- msg.Error = err
- } else {
- msg.Params = buf.Bytes()
- }
- msg.ExtraFields = env.extra
- // add the method
- msg.Method = env.method
- enc := jx.GetEncoder()
- enc.ResetWriter(c.remote)
- err = codec.MarshalMessage(msg, enc)
- if err != nil {
- return err
- }
- return enc.Close()
-}
diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go
index 321e9a3..4a31120 100644
--- a/pkg/server/responsewriter.go
+++ b/pkg/server/responsewriter.go
@@ -2,8 +2,10 @@ package server
import (
"context"
+ "log"
"net/http"
"sync"
+ "time"
"gfx.cafe/open/jrpc/pkg/codec"
"github.com/goccy/go-json"
@@ -41,15 +43,16 @@ func (c *callRespWriter) Send(v any, e error) (err error) {
}
c.sendCalled = true
// defer the sending of this for later
- defer c.doneMu.Release(1)
+ if c.doneMu != nil {
+ defer c.doneMu.Release(1)
+ }
// batch requests are not individually streamed.
// the reason is beacuse i couldn't think of a good way to implement it
// ultimately they need to be buffered. there's some optimistic multiplexing you can
// do, but that felt really complicated and not worth the time.
if c.noStream {
- if e == nil {
+ if e != nil {
c.err = e
- return nil
}
if v != nil {
// json marshaling errors are reported to the handler
@@ -57,20 +60,25 @@ func (c *callRespWriter) Send(v any, e error) (err error) {
if err != nil {
return err
}
- return nil
}
+ return nil
}
+ s := time.Now()
+ log.Println("try")
err = c.cr.mu.Acquire(c.ctx, 1)
if err != nil {
return err
}
+ log.Println("release", time.Since(s))
+ s2 := time.Now()
defer c.cr.mu.Release(1)
err = c.cr.send(c.ctx, &callEnv{
- v: v,
+ v: &v,
err: e,
id: c.msg.ID,
extrafields: c.msg.ExtraFields,
})
+ log.Println("release", time.Since(s2))
err = c.cr.remote.Flush()
if err != nil {
return err
@@ -90,7 +98,12 @@ func (c *callRespWriter) Header() http.Header {
}
func (c *callRespWriter) Notify(method string, v any) error {
-	err := c.cr.notify(c.ctx, &notifyEnv{
+ err := c.cr.mu.Acquire(c.ctx, 1)
+ if err != nil {
+ return err
+ }
+ defer c.cr.mu.Release(1)
+	err = c.cr.notify(c.ctx, &notifyEnv{
method: method,
dat: v,
extra: c.msg.ExtraFields,
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 3ae29cb..8ca1430 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -61,7 +61,6 @@ func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) erro
}
err = s.serveBatch(ctx, incoming, responder)
if err != nil {
- // remote.Flush()
mu.Lock()
defer mu.Unlock()
allErrs = append(allErrs, err)
@@ -69,6 +68,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) erro
}()
}
}()
+ wg.Wait()
allErrs = append(allErrs, err)
if len(allErrs) > 0 {
return errors.Join(allErrs...)
@@ -111,10 +111,12 @@ func (s *Server) serveBatch(ctx context.Context,
rs = append(rs, rw)
// a nil incoming message means an empty response
if v == nil {
- rw.msg = &codec.Message{ID: codec.NewNullIDPtr()}
- continue
+ v = &codec.Message{ID: codec.NewNullIDPtr()}
}
rw.msg = v
+ if len(v.Method) == 0 {
+ rw.err = codec.NewInvalidRequestError("invalid request")
+ }
if v.ID != nil {
totalRequests += 1
}
@@ -131,24 +133,27 @@ func (s *Server) serveBatch(ctx context.Context,
wg.Add(len(rs))
// for each item in the envelope
peerInfo := r.remote.PeerInfo()
- isBatchWithRequests := totalRequests > 1 && !r.batch
batchResults := []*callRespWriter{}
for _, vRef := range rs {
v := vRef
- v.doneMu = doneMu
- if isBatchWithRequests {
+ if r.batch {
v.noStream = true
- batchResults = append(batchResults, v)
+ if v.msg.ID != nil {
+ v.doneMu = doneMu
+ batchResults = append(batchResults, v)
+ }
+ }
+		// respond early to nil or invalid requests
+ if v.err != nil {
+ v.sendCalled = true
+ v.doneMu.Release(1)
+ wg.Done()
+ continue
}
// now process each request in its own goroutine
// TODO: stress test this.
go func() {
defer wg.Done()
- // early respond to nil requests
- if v.msg == nil || len(v.msg.Method) == 0 {
- v.msg.Error = codec.NewInvalidRequestError("invalid request")
- return
- }
req := codec.NewRequestFromMessage(
ctx,
v.msg,
@@ -157,13 +162,13 @@ func (s *Server) serveBatch(ctx context.Context,
s.services.ServeRPC(v, req)
}()
}
- // we only need to do this if this is a batch call with requests
- // first we need to wait for every single request to be completed
- err = doneMu.Acquire(ctx, int64(totalRequests))
- if err != nil {
- return err
- }
- if isBatchWithRequests {
+
+ if r.batch {
+ // we only need to do this if this is a batch call with requests
+ err = doneMu.Acquire(ctx, int64(totalRequests))
+ if err != nil {
+ return err
+ }
err = r.mu.Acquire(ctx, 1)
if err != nil {
return err
@@ -175,8 +180,10 @@ func (s *Server) serveBatch(ctx context.Context,
return err
}
for i, v := range batchResults {
+			// box the payload so callEnv can hold a *any
+			var a any = v.payload
err = r.send(ctx, &callEnv{
- v: v.payload,
+ v: &a,
err: v.err,
id: v.msg.ID,
extrafields: v.msg.ExtraFields,
@@ -198,6 +205,16 @@ func (s *Server) serveBatch(ctx context.Context,
if err != nil {
return err
}
+ } else if totalRequests == 0 {
+ err = r.mu.Acquire(ctx, 1)
+ if err != nil {
+ return err
+ }
+ defer r.mu.Release(1)
+ err := r.remote.Flush()
+ if err != nil {
+ return err
+ }
}
wg.Wait()
return nil
@@ -212,7 +229,7 @@ type callResponder struct {
}
type callEnv struct {
- v any
+ v *any
err error
pkt *codec.Message
id *codec.ID
@@ -248,7 +265,7 @@ func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
// if there is no error, we try to marshal the result
e.Field("result", func(e *jx.Encoder) {
if env.v != nil {
- switch cast := env.v.(type) {
+ switch cast := (*env.v).(type) {
case json.RawMessage:
e.Raw(cast)
default:
--
GitLab