diff --git a/contrib/codecs/broker/broker.go b/contrib/codecs/_broker/broker.go
similarity index 77%
rename from contrib/codecs/broker/broker.go
rename to contrib/codecs/_broker/broker.go
index b234fb4cd6ac3763d93ddc8fc470d19c277fdf93..01dfb5eeb761010f1144f9458f32df777180febc 100644
--- a/contrib/codecs/broker/broker.go
+++ b/contrib/codecs/_broker/broker.go
@@ -3,8 +3,7 @@ package broker
import (
"context"
"encoding/json"
-
- "github.com/go-faster/jx"
+ "io"
)
type ServerSpoke interface {
@@ -22,12 +21,12 @@ type Broker interface {
}
type Replier interface {
- Send(fn func(*jx.Encoder) error) error
+ Send(fn func(io.Writer) error) error
}
-type ReplierFunc func(fn func(*jx.Encoder) error) error
+type ReplierFunc func(fn func(io.Writer) error) error
-func (r ReplierFunc) Send(fn func(*jx.Encoder) error) error {
+func (r ReplierFunc) Send(fn func(io.Writer) error) error {
return r(fn)
}
diff --git a/contrib/codecs/broker/broker_inproc.go b/contrib/codecs/_broker/broker_inproc.go
similarity index 100%
rename from contrib/codecs/broker/broker_inproc.go
rename to contrib/codecs/_broker/broker_inproc.go
diff --git a/contrib/codecs/broker/client.go b/contrib/codecs/_broker/client.go
similarity index 100%
rename from contrib/codecs/broker/client.go
rename to contrib/codecs/_broker/client.go
diff --git a/contrib/codecs/broker/codec.go b/contrib/codecs/_broker/codec.go
similarity index 87%
rename from contrib/codecs/broker/codec.go
rename to contrib/codecs/_broker/codec.go
index 96542ca9f2aafaf6bcf04c88a44bc61244445779..adac8b6ee10592e8207f3ef8d6c3d5732f59ec4c 100644
--- a/contrib/codecs/broker/codec.go
+++ b/contrib/codecs/_broker/codec.go
@@ -8,7 +8,7 @@ import (
"gfx.cafe/open/jrpc/pkg/codec"
"gfx.cafe/open/jrpc/pkg/serverutil"
- "github.com/go-faster/jx"
+ "github.com/gogo/protobuf/io"
)
var _ codec.ReaderWriter = (*Codec)(nil)
@@ -68,7 +68,11 @@ func (c *Codec) Close() error {
return nil
}
-func (c *Codec) Send(fn func(e *jx.Encoder) error) error {
+func (c *Codec) Send(fn func(io.Writer) error) error {
+ return c.replier.Send(fn)
+}
+
+func (c *Codec) Flush() error {
return c.replier.Send(fn)
}
@@ -76,8 +80,3 @@ func (c *Codec) Send(fn func(e *jx.Encoder) error) error {
func (c *Codec) Closed() <-chan struct{} {
return c.closeCh
}
-
-// RemoteAddr returns the peer address of the connection.
-func (c *Codec) RemoteAddr() string {
- return ""
-}
diff --git a/contrib/codecs/broker/codec_test.go b/contrib/codecs/_broker/codec_test.go
similarity index 100%
rename from contrib/codecs/broker/codec_test.go
rename to contrib/codecs/_broker/codec_test.go
diff --git a/contrib/codecs/broker/server.go b/contrib/codecs/_broker/server.go
similarity index 100%
rename from contrib/codecs/broker/server.go
rename to contrib/codecs/_broker/server.go
diff --git a/contrib/codecs/http/codec.go b/contrib/codecs/http/codec.go
index cedd65140fbd1df4c5b29794e9368c3fe1b31b69..e1568746f313e430fec00068d14dbd35eea66f0d 100644
--- a/contrib/codecs/http/codec.go
+++ b/contrib/codecs/http/codec.go
@@ -1,6 +1,7 @@
package http
import (
+ "bufio"
"context"
"encoding/base64"
"errors"
@@ -13,7 +14,6 @@ import (
"gfx.cafe/open/jrpc/pkg/codec"
"gfx.cafe/open/jrpc/pkg/serverutil"
- "github.com/go-faster/jx"
)
var _ codec.ReaderWriter = (*Codec)(nil)
@@ -25,8 +25,7 @@ type Codec struct {
r *http.Request
w http.ResponseWriter
- wr io.Writer
- jx *jx.Encoder
+ wr *bufio.Writer
msgs chan *serverutil.Bundle
errCh chan httpError
@@ -45,15 +44,14 @@ func NewCodec(w http.ResponseWriter, r *http.Request) *Codec {
}
func (c *Codec) Reset(w http.ResponseWriter, r *http.Request) {
- c.wr = w
+ c.wr = bufio.NewWriter(w)
if w == nil {
- c.wr = io.Discard
+ c.wr = bufio.NewWriter(io.Discard)
}
c.r = r
c.w = w
c.msgs = make(chan *serverutil.Bundle, 1)
c.errCh = make(chan httpError, 1)
- c.jx = jx.NewStreamingEncoder(w, 4096)
ctx := c.r.Context()
c.ctx, c.cn = context.WithCancel(ctx)
@@ -221,13 +219,16 @@ func (c *Codec) Close() error {
return nil
}
-func (c *Codec) Send(fn func(e *jx.Encoder) error) error {
- defer c.cn()
- defer c.jx.ResetWriter(c.wr)
- if err := fn(c.jx); err != nil {
+func (c *Codec) Send(fn func(e io.Writer) error) error {
+ if err := fn(c.w); err != nil {
return err
}
- return c.jx.Close()
+ return nil
+}
+
+func (c *Codec) Flush() error {
+ defer c.cn()
+ return c.wr.Flush()
}
// Closed returns a channel which is closed when the connection is closed.
diff --git a/contrib/codecs/rdwr/codec.go b/contrib/codecs/rdwr/codec.go
index a071009f12822f0349aa080a9127140fabf0d483..c50d0cc12d372769c51aed475496857588c48600 100644
--- a/contrib/codecs/rdwr/codec.go
+++ b/contrib/codecs/rdwr/codec.go
@@ -6,7 +6,6 @@ import (
"io"
"sync"
- "github.com/go-faster/jx"
"github.com/goccy/go-json"
"gfx.cafe/open/jrpc/pkg/codec"
@@ -17,10 +16,8 @@ type Codec struct {
ctx context.Context
cn func()
- rd io.Reader
- wrLock sync.Mutex
- wr *bufio.Writer
- jx *jx.Encoder
+ rd io.Reader
+ wr *bufio.Writer
dec *json.Decoder
decBuf json.RawMessage
@@ -36,7 +33,6 @@ func NewCodec(rd io.Reader, wr io.Writer) *Codec {
wr: bufio.NewWriter(wr),
dec: json.NewDecoder(rd),
}
- c.jx = jx.NewStreamingEncoder(wr, 4096)
return c
}
@@ -74,23 +70,16 @@ func (c *Codec) Close() error {
return nil
}
-func (c *Codec) Send(fn func(e *jx.Encoder) error) error {
- c.wrLock.Lock()
- defer c.wrLock.Unlock()
- defer c.jx.ResetWriter(c.wr)
- if err := fn(c.jx); err != nil {
- return err
- }
- if err := c.jx.Close(); err != nil {
- return err
- }
+func (c *Codec) Write(p []byte) (n int, err error) {
+ n, err = c.wr.Write(p)
+ return n, err
+}
+
+func (c *Codec) Flush() error {
if _, err := c.wr.Write([]byte("\n")); err != nil {
return err
}
- if err := c.wr.Flush(); err != nil {
- return err
- }
- return nil
+ return c.wr.Flush()
}
// Closed returns a channel which is closed when the connection is closed.
diff --git a/contrib/codecs/rdwr/rdwr_test.go b/contrib/codecs/rdwr/rdwr_test.go
index a2ee61ddc0104a12f4120139487a8379754bd917..db0e2fa7c0dc9c2316d8a8e8dde1608c91f8dc82 100644
--- a/contrib/codecs/rdwr/rdwr_test.go
+++ b/contrib/codecs/rdwr/rdwr_test.go
@@ -9,6 +9,7 @@ import (
"gfx.cafe/open/jrpc/contrib/jmux"
"gfx.cafe/open/jrpc/pkg/server"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -24,7 +25,8 @@ func TestRDWRSetup(t *testing.T) {
clientCodec := rdwr.NewCodec(rd_s, wr_c)
client := rdwr.NewClient(rd_c, wr_s)
go func() {
- srv.ServeCodec(ctx, clientCodec)
+ err := srv.ServeCodec(ctx, clientCodec)
+ assert.NoError(t, err)
}()
var res any
diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go
index 017e06d54caa466f7f8da8a9175b9f5394d633a8..efb706c43633fe2200d1c66d10e1c0c817fc5437 100644
--- a/contrib/codecs/websocket/codec.go
+++ b/contrib/codecs/websocket/codec.go
@@ -8,7 +8,6 @@ import (
"time"
"gfx.cafe/open/websocket"
- "github.com/go-faster/jx"
"github.com/goccy/go-json"
"gfx.cafe/open/jrpc/pkg/codec"
@@ -19,8 +18,8 @@ type Codec struct {
closed chan struct{}
conn *websocket.Conn
- jx *jx.Encoder
- wrLock sync.Mutex
+ currentFrame io.WriteCloser
+ wrLock sync.Mutex
decBuf json.RawMessage
decLock sync.Mutex
@@ -33,7 +32,6 @@ func newWebsocketCodec(ctx context.Context, conn *websocket.Conn, host string, r
c := &Codec{
closed: make(chan struct{}),
conn: conn,
- jx: jx.NewStreamingEncoder(nil, 4096),
}
c.i.Transport = "ws"
// Fill in connection details.
@@ -96,22 +94,31 @@ func (c *Codec) ReadBatch(ctx context.Context) ([]*codec.Message, bool, error) {
return ans.Messages, ans.Batch, nil
}
-func (c *Codec) Send(fn func(e *jx.Encoder) error) error {
+func (c *Codec) Write(p []byte) (n int, err error) {
c.wrLock.Lock()
defer c.wrLock.Unlock()
+ if c.currentFrame == nil {
+ wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
+ if err != nil {
+ return 0, err
+ }
- wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
- if err != nil {
- return err
+ c.currentFrame = wr
}
- c.jx.ResetWriter(wr)
- if err = fn(c.jx); err != nil {
- return err
- }
- if err = c.jx.Close(); err != nil {
- return err
+ return c.currentFrame.Write(p)
+}
+
+func (c *Codec) Flush() error {
+ c.wrLock.Lock()
+ defer c.wrLock.Unlock()
+ if c.currentFrame == nil {
+ wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
+ if err != nil {
+ return err
+ }
+ return wr.Close()
}
- return wr.Close()
+ return c.currentFrame.Close()
}
func (c *Codec) PeerInfo() codec.PeerInfo {
diff --git a/go.mod b/go.mod
index 22eef0318323097b2a9f2beab78a22f3813e4c4a..c912ba093165b9aaba37ae0303e4d8248d2800eb 100644
--- a/go.mod
+++ b/go.mod
@@ -11,11 +11,11 @@ require (
gfx.cafe/util/go/bufpool v0.0.0-20230721185457-c559e86c829c
gfx.cafe/util/go/frand v0.0.0-20230721185457-c559e86c829c
gfx.cafe/util/go/generic v0.0.0-20230721185457-c559e86c829c
- github.com/alecthomas/kong v0.8.0
github.com/go-faster/jx v1.1.0
github.com/goccy/go-json v0.10.2
github.com/rs/xid v1.5.0
github.com/stretchr/testify v1.8.4
+ golang.org/x/sync v0.4.0
sigs.k8s.io/yaml v1.3.0
)
diff --git a/go.sum b/go.sum
index 342eda528d97db9d91c704217ec6562b74f2b08d..ef1477ba7ccd35f66b491e65ce0f5795afde6040 100644
--- a/go.sum
+++ b/go.sum
@@ -8,12 +8,6 @@ gfx.cafe/util/go/generic v0.0.0-20230721185457-c559e86c829c h1:alCfDKmPC0EC0KGlZ
gfx.cafe/util/go/generic v0.0.0-20230721185457-c559e86c829c/go.mod h1:WvSX4JsCRBuIXj0FRBFX9YLg+2SoL3w8Ww19uZO9yNE=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY=
github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA=
-github.com/alecthomas/assert/v2 v2.1.0 h1:tbredtNcQnoSd3QBhQWI7QZ3XHOVkw1Moklp2ojoH/0=
-github.com/alecthomas/assert/v2 v2.1.0/go.mod h1:b/+1DI2Q6NckYi+3mXyH3wFb8qG37K/DuK80n7WefXA=
-github.com/alecthomas/kong v0.8.0 h1:ryDCzutfIqJPnNn0omnrgHLbAggDQM2VWHikE1xqK7s=
-github.com/alecthomas/kong v0.8.0/go.mod h1:n1iCIO2xS46oE8ZfYCNDqdR0b0wZNrXAIAqro/2132U=
-github.com/alecthomas/repr v0.1.0 h1:ENn2e1+J3k09gyj2shc0dHr/yjaWSHRlrJ4DPMevDqE=
-github.com/alecthomas/repr v0.1.0/go.mod h1:2kn6fqh/zIyPLmm3ugklbEi5hg5wS435eygvNfaDQL8=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -26,8 +20,6 @@ github.com/go-faster/errors v0.6.1/go.mod h1:5MGV2/2T9yvlrbhe9pD9LO5Z/2zCSq2T8j+
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
-github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
-github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM=
github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
@@ -51,6 +43,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/exp v0.0.0-20230206171751-46f607a40771 h1:xP7rWLUr1e1n2xkK5YB4LI0hPEy3LJC6Wk+D4pGlOJg=
golang.org/x/exp v0.0.0-20230206171751-46f607a40771/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
+golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
+golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
diff --git a/pkg/codec/errors.go b/pkg/codec/errors.go
index a8c2f0f6ebef6bf7683778e92f208a8c3915f62d..b9d1aad201964a61e6c870c27da1ee4938b45a17 100644
--- a/pkg/codec/errors.go
+++ b/pkg/codec/errors.go
@@ -26,7 +26,9 @@ const (
)
var (
- ErrIllegalExtraField = errors.New("invalid extra field")
+ ErrIllegalExtraField = errors.New("invalid extra field")
+ ErrSendAlreadyCalled = errors.New("send already called")
+ ErrCantSendNotification = errors.New("can't send to a notification")
)
// Error wraps RPC errors, which contain an error code in addition to the message.
diff --git a/pkg/codec/json.go b/pkg/codec/json.go
index ff2e1cdcf71ffce399ab6cdbf3565ff11d3049af..77f2c8a5a6ba72369de50d76cc8d3368e0dbe1d3 100644
--- a/pkg/codec/json.go
+++ b/pkg/codec/json.go
@@ -15,11 +15,7 @@ func NewNull() json.RawMessage {
return json.RawMessage("null")
}
-// RequestField is an idea borrowed from sourcegraphs implementation.
-type RequestField struct {
- Name string
- Value json.RawMessage
-}
+type ExtraFields map[string]json.RawMessage
// A value of this type can a JSON-RPC request, notification, successful response or
// error response. Which one it is depends on the fields.
@@ -30,7 +26,7 @@ type Message struct {
Result json.RawMessage `json:"result,omitempty"`
Error error `json:"error,omitempty"`
- ExtraFields []RequestField `json:"-"`
+ ExtraFields ExtraFields `json:"-"`
}
func MarshalMessage(m *Message, enc *jx.Encoder) error {
@@ -49,9 +45,9 @@ func MarshalMessage(m *Message, enc *jx.Encoder) error {
e.Str(m.Method)
})
}
- for _, v := range m.ExtraFields {
- e.Field(v.Name, func(e *jx.Encoder) {
- e.Raw(v.Value)
+ for k, v := range m.ExtraFields {
+ e.Field(k, func(e *jx.Encoder) {
+ e.Raw(v)
})
}
if m.Error != nil {
@@ -88,10 +84,7 @@ func UnmarshalMessage(m *Message, dec *jx.Decoder) error {
}
buf := bytes.NewBuffer(make(json.RawMessage, len(val)))
buf.Write(val)
- m.ExtraFields = append(m.ExtraFields, RequestField{
- Name: key,
- Value: buf.Bytes(),
- })
+ m.ExtraFields[key] = buf.Bytes()
case "jsonrpc":
value, err := d.Str()
if err != nil {
@@ -217,21 +210,30 @@ func IsBatchMessage(raw json.RawMessage) bool {
return false
}
-func (m *Message) SetExtraField(name string, v any) error {
+func (m ExtraFields) SetExtraField(name string, v any) (err error) {
switch name {
case "id", "jsonrpc", "method", "params", "result", "error":
return fmt.Errorf("%w: %q", ErrIllegalExtraField, name)
}
+ if v == nil {
+ delete(m, name)
+ }
val, err := json.Marshal(v)
if err != nil {
return err
}
- m.ExtraFields = append(m.ExtraFields, RequestField{
- Name: name,
- Value: val,
- })
+ m[name] = val
return nil
}
+func (m ExtraFields) Clear() {
+ for k := range m {
+ delete(m, k)
+ }
+}
+
+func (m *Message) SetExtraField(name string, v any) error {
+ return m.ExtraFields.SetExtraField(name, v)
+}
// parseMessage parses raw bytes as a (batch of) JSON-RPC message(s). There are no error
// checks in this function because the raw message has already been syntax-checked when it
diff --git a/pkg/codec/transport.go b/pkg/codec/transport.go
index 2a9543d03bd344c7c057302dac2003e79722bfca..8291413c05483178b29aec5fd952e1b26f69649f 100644
--- a/pkg/codec/transport.go
+++ b/pkg/codec/transport.go
@@ -2,10 +2,16 @@ package codec
import (
"context"
-
- "github.com/go-faster/jx"
+ "io"
+ "net"
)
+type Listener interface {
+ Accept() (ReaderWriter, error)
+ Close() error
+ Addr() net.Addr
+}
+
// ReaderWriter represents a single stream
// this stream can be used to send/receive an arbitrary amount of requests and notifications
type ReaderWriter interface {
@@ -18,7 +24,7 @@ type ReaderWriter interface {
type Reader interface {
// gets the peer info
PeerInfo() PeerInfo
- // json.RawMessage can be an array of requests. if it is, then it is a batch request
+ // reads a batch of messages
ReadBatch(ctx context.Context) (msgs []*Message, batch bool, err error)
// closes the connection
Close() error
@@ -28,9 +34,8 @@ type Reader interface {
// Implementations must be safe for concurrent use.
type Writer interface {
// write json blob to stream
- Send(fn func(e *jx.Encoder) error) error
+ io.Writer
+ Flush() error
// Closed returns a channel which is closed when the connection is closed.
Closed() <-chan struct{}
- // RemoteAddr returns the peer address of the connection.
- RemoteAddr() string
}
diff --git a/pkg/server/] b/pkg/server/]
new file mode 100644
index 0000000000000000000000000000000000000000..381125251def0108e7d73c619aa3f7df139a3dc2
--- /dev/null
+++ b/pkg/server/]
@@ -0,0 +1,315 @@
+package server
+
+import (
+ "context"
+ "errors"
+ "sync"
+
+ "gfx.cafe/open/jrpc/pkg/codec"
+ "golang.org/x/sync/semaphore"
+
+ "gfx.cafe/util/go/bufpool"
+
+ "github.com/go-faster/jx"
+ "github.com/goccy/go-json"
+)
+
+// Server is an RPC server.
+// it is in charge of calling the handler on the message object, the json encoding of responses, and dealing with batch semantics.
+// a server can be used to listenandserve multiple codecs at a time
+type Server struct {
+ services codec.Handler
+
+ lctx context.Context
+ cn context.CancelFunc
+}
+
+// NewServer creates a new server instance with no registered handlers.
+func NewServer(r codec.Handler) *Server {
+ server := &Server{services: r}
+ server.lctx, server.cn = context.WithCancel(context.Background())
+ return server
+}
+
+// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes
+// the response back using the given codec. It will block until the codec is closed
+func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) error {
+ defer remote.Close()
+
+ batchMu := semaphore.NewWeighted(1)
+ // add a cancel to the context so we can cancel all the child tasks on return
+ ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
+ defer cn()
+
+ allErrs := []error{}
+ var mu sync.Mutex
+ wg := sync.WaitGroup{}
+ err := func() error {
+ for {
+ // read messages from the stream synchronously
+ incoming, batch, err := remote.ReadBatch(ctx)
+ if err != nil {
+ return err
+ }
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ responder := &callResponder{
+ remote: remote,
+ batchMu: batchMu,
+ batch: batch,
+ }
+ err = s.serveBatch(ctx, incoming, responder)
+ if err != nil {
+ mu.Lock()
+ defer mu.Unlock()
+ allErrs = append(allErrs, err)
+ }
+ }()
+ }
+ }()
+ allErrs = append(allErrs, err)
+ if len(allErrs) > 0 {
+ return errors.Join(allErrs...)
+ }
+ return nil
+}
+
+func (s *Server) Shutdown(ctx context.Context) {
+ s.cn()
+}
+
+func (s *Server) serveBatch(ctx context.Context,
+ incoming []*codec.Message,
+ r *callResponder,
+) error {
+ // check for empty batch
+ if r.batch && len(incoming) == 0 {
+ // if it is empty batch, send the empty batch error and immediately return
+ err := r.send(ctx, &callEnv{
+ pkt: &codec.Message{
+ ID: codec.NewNullIDPtr(),
+ Error: codec.NewInvalidRequestError("empty batch"),
+ },
+ })
+ if err != nil {
+ return err
+ }
+ }
+
+ rs := []*callRespWriter{}
+
+ totalRequests := 0
+ // populate the envelope we are about to send. this is synchronous pre-prpcessing
+ for _, v := range incoming {
+ // create the response writer
+ rw := &callRespWriter{}
+ rs = append(rs, rw)
+ // a nil incoming message means an empty response
+ if v == nil {
+ rw.msg = &codec.Message{ID: codec.NewNullIDPtr()}
+ continue
+ }
+ rw.msg = v
+ if v.ID != nil {
+ totalRequests += 1
+ }
+ }
+ var doneMu *semaphore.Weighted
+ doneMu = semaphore.NewWeighted(int64(totalRequests))
+
+ if totalRequests == 0 {
+ err := r.remote.Flush()
+ if err != nil {
+ return err
+ }
+ }
+
+ // create a waitgroup for everything
+ wg := sync.WaitGroup{}
+ wg.Add(len(rs))
+ // for each item in the envelope
+ peerInfo := r.remote.PeerInfo()
+ isBatchWithRequests := totalRequests > 1 && !r.batch
+ for _, vRef := range rs {
+ v := vRef
+ if isBatchWithRequests {
+ v.noStream = isBatchWithRequests
+ v.doneMu = doneMu
+ }
+ // now process each request in its own goroutine
+ // TODO: stress test this.
+ go func() {
+ defer wg.Done()
+ // early respond to nil requests
+ if v.msg == nil || len(v.msg.Method) == 0 {
+ v.msg.Error = codec.NewInvalidRequestError("invalid request")
+ return
+ }
+ req := codec.NewRequestFromMessage(
+ ctx,
+ v.msg,
+ )
+ req.Peer = peerInfo
+ s.services.ServeRPC(v, req)
+ }()
+ }
+ // we only need to do this if this is a batch call with requests
+ if isBatchWithRequests {
+ // first we need to wait for every single request to be completed
+ err := doneMu.Acquire(ctx, int64(totalRequests))
+ if err != nil {
+ return err
+ }
+ // now write the prefix
+ _, err = r.remote.Write([]byte{'['})
+ if err != nil {
+ return err
+ }
+ // release them, one by one
+ for i := 0; i < totalRequests; i++ {
+ // release one
+ canCh <- struct{}{}
+ // wait for finish
+ <-doneCh
+ // write the comma or ]
+ char := ','
+ if i == totalRequests-1 {
+ char = ']'
+ }
+ _, err = r.remote.Write([]byte{byte(char)})
+ if err != nil {
+ return err
+ }
+ }
+ }
+ wg.Wait()
+ return nil
+}
+
+type callResponder struct {
+ remote codec.ReaderWriter
+ mu *semaphore.Weighted
+ batchMu *semaphore.Weighted
+
+ batch bool
+ batchStarted bool
+}
+
+type callEnv struct {
+ v any
+ err error
+ pkt *codec.Message
+ id *codec.ID
+ extrafields codec.ExtraFields
+}
+
+func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
+ err = c.mu.Acquire(ctx, 1)
+ if err != nil {
+ return err
+ }
+ defer c.mu.Release(1)
+ // notification gets nothing
+ // if all msgs in batch are notification, we trigger an allSkip and write nothing
+ //if c.batch {
+ // allSkip := true
+ // for _, v := range env.responses {
+ // if v.skip != true {
+ // allSkip = false
+ // }
+ // }
+ // if allSkip {
+ // return c.remote.Send(func(e *jx.Encoder) error { return nil })
+ // }
+ //}
+ // create the streaming encoder
+ enc := jx.GetEncoder()
+ enc.ResetWriter(c.remote)
+ enc.Obj(func(e *jx.Encoder) {
+ e.Field("jsonrpc", func(e *jx.Encoder) {
+ e.Str("2.0")
+ })
+ if env.id != nil {
+ e.Field("id", func(e *jx.Encoder) {
+ e.Raw(env.id.RawMessage())
+ })
+ }
+ if env.extrafields != nil {
+ for k, v := range env.extrafields {
+ e.Field(k, func(e *jx.Encoder) {
+ e.Raw(v)
+ })
+ }
+ }
+ if env.err != nil {
+ e.Field("error", func(e *jx.Encoder) {
+ codec.EncodeError(e, env.err)
+ })
+ } else {
+ // if there is no error, we try to marshal the result
+ e.Field("result", func(e *jx.Encoder) {
+ if env.v != nil {
+ switch cast := env.v.(type) {
+ case json.RawMessage:
+ e.Raw(cast)
+ default:
+ err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) {
+ eo.DisableNewline = true
+ })
+ if err != nil {
+ return
+ }
+ }
+ } else {
+ e.Null()
+ }
+ })
+ }
+ })
+ // a json encoding error here is possibly fatal....
+ if err != nil {
+ return err
+ }
+ return enc.Close()
+}
+
+type notifyEnv struct {
+ method string
+ dat any
+ extra codec.ExtraFields
+}
+
+func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error {
+ err := c.mu.Acquire(ctx, 1)
+ if err != nil {
+ return err
+ }
+ defer c.mu.Release(1)
+ err = c.batchMu.Acquire(ctx, 1)
+ if err != nil {
+ return err
+ }
+ defer c.batchMu.Release(1)
+ msg := &codec.Message{}
+ // allocate a temp buffer for this packet
+ buf := bufpool.GetStd()
+ defer bufpool.PutStd(buf)
+ err = json.NewEncoder(buf).Encode(env.dat)
+ if err != nil {
+ msg.Error = err
+ } else {
+ msg.Params = buf.Bytes()
+ }
+ msg.ExtraFields = env.extra
+ // add the method
+ msg.Method = env.method
+ enc := jx.GetEncoder()
+ enc.ResetWriter(c.remote)
+ err = codec.MarshalMessage(msg, enc)
+ if err != nil {
+ return err
+ }
+ return enc.Close()
+}
diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go
index 79b688205791b5e91de06b103df4584e12cb2a09..321e9a395cd6016a0607a17c05a8cf0aea94d918 100644
--- a/pkg/server/responsewriter.go
+++ b/pkg/server/responsewriter.go
@@ -1,36 +1,87 @@
package server
import (
+ "context"
"net/http"
+ "sync"
"gfx.cafe/open/jrpc/pkg/codec"
+ "github.com/goccy/go-json"
+ "golang.org/x/sync/semaphore"
)
var _ codec.ResponseWriter = (*callRespWriter)(nil)
+// callRespWriter is NOT thread safe
type callRespWriter struct {
+ cr *callResponder
msg *codec.Message
+ ctx context.Context
- pkt *codec.Message
+ noStream bool
+ doneMu *semaphore.Weighted
- dat any
- skip bool
- header http.Header
+ payload json.RawMessage
+ err error
- notifications func(env *notifyEnv) error
+ sendCalled bool
+ header http.Header
+
+ mu sync.Mutex
}
-func (c *callRespWriter) Send(v any, err error) error {
+func (c *callRespWriter) Send(v any, e error) (err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.msg.ID == nil {
+ return codec.ErrCantSendNotification
+ }
+ if c.sendCalled {
+ return codec.ErrSendAlreadyCalled
+ }
+ c.sendCalled = true
+ // defer the sending of this for later
+ defer c.doneMu.Release(1)
+ // batch requests are not individually streamed.
+ // the reason is beacuse i couldn't think of a good way to implement it
+ // ultimately they need to be buffered. there's some optimistic multiplexing you can
+ // do, but that felt really complicated and not worth the time.
+ if c.noStream {
+ if e == nil {
+ c.err = e
+ return nil
+ }
+ if v != nil {
+ // json marshaling errors are reported to the handler
+ c.payload, err = json.Marshal(v)
+ if err != nil {
+ return err
+ }
+ return nil
+ }
+ }
+ err = c.cr.mu.Acquire(c.ctx, 1)
+ if err != nil {
+ return err
+ }
+ defer c.cr.mu.Release(1)
+ err = c.cr.send(c.ctx, &callEnv{
+ v: v,
+ err: e,
+ id: c.msg.ID,
+ extrafields: c.msg.ExtraFields,
+ })
+ err = c.cr.remote.Flush()
if err != nil {
- c.pkt.Error = err
- return nil
+ return err
}
- c.dat = v
return nil
}
func (c *callRespWriter) SetExtraField(k string, v any) error {
- c.pkt.SetExtraField(k, v)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.msg.SetExtraField(k, v)
return nil
}
@@ -39,9 +90,17 @@ func (c *callRespWriter) Header() http.Header {
}
func (c *callRespWriter) Notify(method string, v any) error {
- return c.notifications(¬ifyEnv{
+ err := c.cr.notify(c.ctx, ¬ifyEnv{
method: method,
dat: v,
- extra: c.pkt.ExtraFields,
+ extra: c.msg.ExtraFields,
})
+ if err != nil {
+ return err
+ }
+ err = c.cr.remote.Flush()
+ if err != nil {
+ return err
+ }
+ return nil
}
diff --git a/pkg/server/server.go b/pkg/server/server.go
index e6df14654b39163c08a82d725d10eb7e336168a0..3ae29cbf1c70d6db2a6bed9e0c360491ad45c7f6 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -2,10 +2,11 @@ package server
import (
"context"
- "log/slog"
+ "errors"
"sync"
"gfx.cafe/open/jrpc/pkg/codec"
+ "golang.org/x/sync/semaphore"
"gfx.cafe/util/go/bufpool"
@@ -18,11 +19,15 @@ import (
// a server can be used to listenandserve multiple codecs at a time
type Server struct {
services codec.Handler
+
+ lctx context.Context
+ cn context.CancelFunc
}
// NewServer creates a new server instance with no registered handlers.
func NewServer(r codec.Handler) *Server {
server := &Server{services: r}
+ server.lctx, server.cn = context.WithCancel(context.Background())
return server
}
@@ -30,250 +35,274 @@ func NewServer(r codec.Handler) *Server {
// the response back using the given codec. It will block until the codec is closed
func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) error {
defer remote.Close()
- responder := &callResponder{
- remote: remote,
- }
+
+ sema := semaphore.NewWeighted(1)
// add a cancel to the context so we can cancel all the child tasks on return
ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
defer cn()
- errch := make(chan error)
- go func() {
+ allErrs := []error{}
+ var mu sync.Mutex
+ wg := sync.WaitGroup{}
+ err := func() error {
for {
// read messages from the stream synchronously
incoming, batch, err := remote.ReadBatch(ctx)
if err != nil {
- select {
- case errch <- err:
- case <-ctx.Done():
- }
- return
+ return err
}
+ wg.Add(1)
go func() {
- err = s.serveBatch(ctx, incoming, batch, remote, responder)
+ defer wg.Done()
+ responder := &callResponder{
+ remote: remote,
+ batch: batch,
+ mu: sema,
+ }
+ err = s.serveBatch(ctx, incoming, responder)
if err != nil {
- select {
- case errch <- err:
- case <-ctx.Done():
- }
- return
+ // remote.Flush()
+ mu.Lock()
+ defer mu.Unlock()
+ allErrs = append(allErrs, err)
}
}()
}
}()
- // exit on either the first error, or the context closing.
- select {
- case <-ctx.Done():
- return nil
- case err := <-errch:
- return err
+ allErrs = append(allErrs, err)
+ if len(allErrs) > 0 {
+ return errors.Join(allErrs...)
}
+ return nil
+}
+
+func (s *Server) Shutdown(ctx context.Context) {
+ s.cn()
}
func (s *Server) serveBatch(ctx context.Context,
incoming []*codec.Message,
- batch bool,
- remote codec.ReaderWriter, responder *callResponder) error {
- env := &callEnv{
- batch: batch,
- }
-
+ r *callResponder,
+) error {
// check for empty batch
- if batch && len(incoming) == 0 {
+ if r.batch && len(incoming) == 0 {
// if it is empty batch, send the empty batch error and immediately return
- return responder.send(ctx, &callEnv{
- responses: []*callRespWriter{{
- pkt: &codec.Message{
- ID: codec.NewNullIDPtr(),
- Error: codec.NewInvalidRequestError("empty batch"),
- },
- }},
- batch: false,
+ err := r.send(ctx, &callEnv{
+ pkt: &codec.Message{
+ ID: codec.NewNullIDPtr(),
+ Error: codec.NewInvalidRequestError("empty batch"),
+ },
})
+ if err != nil {
+ return err
+ }
}
+ rs := []*callRespWriter{}
+
+ totalRequests := 0
// populate the envelope we are about to send. this is synchronous pre-prpcessing
for _, v := range incoming {
// create the response writer
rw := &callRespWriter{
- notifications: func(env *notifyEnv) error { return responder.notify(ctx, env) },
- header: remote.PeerInfo().HTTP.Headers,
+ ctx: ctx,
+ cr: r,
}
- env.responses = append(env.responses, rw)
+ rs = append(rs, rw)
// a nil incoming message means an empty response
if v == nil {
rw.msg = &codec.Message{ID: codec.NewNullIDPtr()}
- rw.pkt = &codec.Message{ID: codec.NewNullIDPtr()}
continue
}
rw.msg = v
- if v.ID == nil {
- rw.pkt = &codec.Message{ID: codec.NewNullIDPtr()}
- continue
+ if v.ID != nil {
+ totalRequests += 1
}
- rw.pkt = &codec.Message{ID: v.ID}
+ }
+ var doneMu *semaphore.Weighted
+ doneMu = semaphore.NewWeighted(int64(totalRequests))
+ err := doneMu.Acquire(ctx, int64(totalRequests))
+ if err != nil {
+ return err
}
- // create a waitgroup
+ // create a waitgroup for everything
wg := sync.WaitGroup{}
- wg.Add(len(env.responses))
+ wg.Add(len(rs))
// for each item in the envelope
- peerInfo := remote.PeerInfo()
- for _, vRef := range env.responses {
+ peerInfo := r.remote.PeerInfo()
+ isBatchWithRequests := totalRequests > 1 && !r.batch
+ batchResults := []*callRespWriter{}
+ for _, vRef := range rs {
v := vRef
- // process each request in its own goroutine
+ v.doneMu = doneMu
+ if isBatchWithRequests {
+ v.noStream = true
+ batchResults = append(batchResults, v)
+ }
+ // now process each request in its own goroutine
+ // TODO: stress test this.
go func() {
defer wg.Done()
// early respond to nil requests
if v.msg == nil || len(v.msg.Method) == 0 {
- v.pkt.Error = codec.NewInvalidRequestError("invalid request")
- return
- }
- if v.msg.ID == nil || v.msg.ID.IsNull() {
- // it's a notification, so we mark skip and we don't write anything for it
- v.skip = true
+ v.msg.Error = codec.NewInvalidRequestError("invalid request")
return
}
- r := codec.NewRequestFromMessage(
+ req := codec.NewRequestFromMessage(
ctx,
v.msg,
)
- r.Peer = peerInfo
- s.services.ServeRPC(v, r)
+ req.Peer = peerInfo
+ s.services.ServeRPC(v, req)
}()
}
- wg.Wait()
- return responder.send(ctx, env)
-}
-
-type callResponder struct {
- remote codec.ReaderWriter
- mu sync.Mutex
-}
-
-type notifyEnv struct {
- method string
- dat any
- extra []codec.RequestField
-}
-
-func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error {
- err := c.remote.Send(func(e *jx.Encoder) error {
- msg := &codec.Message{}
- var err error
- // allocate a temp buffer for this packet
- buf := bufpool.GetStd()
- defer bufpool.PutStd(buf)
- err = json.NewEncoder(buf).Encode(env.dat)
+ // we only need to do this if this is a batch call with requests
+ // first we need to wait for every single request to be completed
+ err = doneMu.Acquire(ctx, int64(totalRequests))
+ if err != nil {
+ return err
+ }
+ if isBatchWithRequests {
+ err = r.mu.Acquire(ctx, 1)
if err != nil {
- msg.Error = err
- } else {
- msg.Params = buf.Bytes()
+ return err
}
- msg.ExtraFields = env.extra
- // add the method
- msg.Method = env.method
- err = codec.MarshalMessage(msg, e)
+ defer r.mu.Release(1)
+ // write them, one by one
+ _, err = r.remote.Write([]byte{'['})
+ if err != nil {
+ return err
+ }
+ for i, v := range batchResults {
+ err = r.send(ctx, &callEnv{
+ v: v.payload,
+ err: v.err,
+ id: v.msg.ID,
+ extrafields: v.msg.ExtraFields,
+ })
+ if err != nil {
+ return err
+ }
+ // write the comma or ]
+ char := ','
+ if i == len(batchResults)-1 {
+ char = ']'
+ }
+ _, err = r.remote.Write([]byte{byte(char)})
+ if err != nil {
+ return err
+ }
+ }
+ err = r.remote.Flush()
if err != nil {
return err
}
- return nil
- })
- if err != nil {
- return err
}
+ wg.Wait()
return nil
}
+type callResponder struct {
+ remote codec.ReaderWriter
+ mu *semaphore.Weighted
+
+ batch bool
+ batchStarted bool
+}
+
type callEnv struct {
- responses []*callRespWriter
- batch bool
+ v any
+ err error
+ pkt *codec.Message
+ id *codec.ID
+ extrafields codec.ExtraFields
}
func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
- // notification gets nothing
- // if all msgs in batch are notification, we trigger an allSkip and write nothing
- if env.batch {
- allSkip := true
- for _, v := range env.responses {
- if v.skip != true {
- allSkip = false
- }
- }
- if allSkip {
- return c.remote.Send(func(e *jx.Encoder) error { return nil })
- }
- }
- // create the streaming encoder
- err = c.remote.Send(func(enc *jx.Encoder) error {
- if env.batch {
- enc.ArrStart()
+ enc := jx.GetEncoder()
+ defer jx.PutEncoder(enc)
+ enc.Grow(4096)
+ enc.ResetWriter(c.remote)
+ enc.Obj(func(e *jx.Encoder) {
+ e.Field("jsonrpc", func(e *jx.Encoder) {
+ e.Str("2.0")
+ })
+ if env.id != nil {
+ e.Field("id", func(e *jx.Encoder) {
+ e.Raw(env.id.RawMessage())
+ })
}
- for _, v := range env.responses {
- msg := v.pkt
- // if we are a batch AND we are supposed to skip, then continue
- // this means that for a non-batch notification, we do not skip! this is to ensure we get always a "response" for http-like endpoints
- if env.batch && v.skip {
- continue
- }
- m := msg
- enc.Obj(func(e *jx.Encoder) {
- e.Field("jsonrpc", func(e *jx.Encoder) {
- e.Str("2.0")
+ if env.extrafields != nil {
+ for k, v := range env.extrafields {
+ e.Field(k, func(e *jx.Encoder) {
+ e.Raw(v)
})
- if m.ID != nil {
- e.Field("id", func(e *jx.Encoder) {
- e.Raw(m.ID.RawMessage())
- })
- }
- if m.Method != "" {
- e.Field("method", func(e *jx.Encoder) {
- e.Str(m.Method)
- })
- }
- for _, v := range m.ExtraFields {
- e.Field(v.Name, func(e *jx.Encoder) {
- e.Raw(v.Value)
- })
- }
- if m.Error != nil {
- e.Field("error", func(e *jx.Encoder) {
- codec.EncodeError(e, m.Error)
- })
- } else {
- // if there is no error, we try to marshal the result
- e.Field("result", func(e *jx.Encoder) {
- if v.dat != nil {
- switch c := v.dat.(type) {
- case json.RawMessage:
- e.Raw(c)
- default:
- err = json.NewEncoder(e).EncodeWithOption(v.dat, func(eo *json.EncodeOption) {
- eo.DisableNewline = true
- })
- if err != nil {
- return
- }
- }
- } else {
- e.Null()
+ }
+ }
+ if env.err != nil {
+ e.Field("error", func(e *jx.Encoder) {
+ codec.EncodeError(e, env.err)
+ })
+ } else {
+ // if there is no error, we try to marshal the result
+ e.Field("result", func(e *jx.Encoder) {
+ if env.v != nil {
+ switch cast := env.v.(type) {
+ case json.RawMessage:
+ e.Raw(cast)
+ default:
+ err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) {
+ eo.DisableNewline = true
+ })
+ if err != nil {
+ return
}
- })
+ }
+ } else {
+ e.Null()
}
})
- // a json encoding error here is possibly fatal....
- if err != nil {
- slog.Error("codec json encoding err", "err", err)
- return err
- }
- }
- if env.batch {
- enc.ArrEnd()
}
- return nil
})
+ // a json encoding error here is possibly fatal....
+ if err != nil {
+ return err
+ }
+ err = enc.Close()
if err != nil {
return err
}
return nil
}
+
+type notifyEnv struct {
+ method string
+ dat any
+ extra codec.ExtraFields
+}
+
+func (c *callResponder) notify(ctx context.Context, env *notifyEnv) (err error) {
+ msg := &codec.Message{}
+ // allocate a temp buffer for this packet
+ buf := bufpool.GetStd()
+ defer bufpool.PutStd(buf)
+ err = json.NewEncoder(buf).Encode(env.dat)
+ if err != nil {
+ msg.Error = err
+ } else {
+ msg.Params = buf.Bytes()
+ }
+ msg.ExtraFields = env.extra
+ // add the method
+ msg.Method = env.method
+ enc := jx.GetEncoder()
+ defer jx.PutEncoder(enc)
+ enc.Grow(4096)
+ enc.ResetWriter(c.remote)
+ err = codec.MarshalMessage(msg, enc)
+ if err != nil {
+ return err
+ }
+ return enc.Close()
+}