Code owners
Assign users and groups as approvers for specific file changes. Learn more.
server.go 8.05 KiB
package server
import (
"context"
"errors"
"sync"
"golang.org/x/sync/semaphore"
"gfx.cafe/open/jrpc/pkg/jsonrpc"
"gfx.cafe/util/go/bufpool"
"github.com/go-faster/jx"
"github.com/goccy/go-json"
)
// Server is an RPC server.
// It is in charge of calling the handler on each incoming message, JSON-encoding
// the responses, and dealing with batch semantics.
// A single Server can serve multiple codecs at the same time (see ServeCodec).
type Server struct {
	// services is the handler every request is dispatched to.
	services jsonrpc.Handler
	// lctx is the server-lifetime context created by NewServer.
	lctx context.Context
	// cn cancels lctx; Shutdown calls it.
	cn context.CancelFunc
}
// NewServer creates a server that dispatches every request it receives to
// the handler r. The server also owns a background lifetime context that is
// cancelled by Shutdown.
func NewServer(r jsonrpc.Handler) *Server {
	lctx, cn := context.WithCancel(context.Background())
	return &Server{
		services: r,
		lctx:     lctx,
		cn:       cn,
	}
}
// ServeCodec reads incoming requests from remote, calls the handler for each of
// them and writes the responses back over the same codec. It blocks until
// reading from the codec fails (for example because the codec was closed or
// its context was cancelled), then waits for every in-flight batch to finish
// before returning. remote is closed on return.
//
// Each batch read from remote is served on its own goroutine. All responders
// for this codec share one weight-1 semaphore that guards writes to remote.
// Calling Shutdown on the server cancels the context used for this codec.
//
// The returned error joins any errors returned while serving individual
// batches with the error that ended the read loop.
func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error {
	defer remote.Close()
	// weight-1 semaphore used as a context-aware mutex around writes to remote
	sema := semaphore.NewWeighted(1)
	// add a cancel to the context so we can cancel all the child tasks on return
	ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
	defer cn()
	// Tie this codec to the server lifetime so that Shutdown (which cancels
	// s.lctx) also stops it. The watcher exits as soon as either context is
	// done; ctx is always cancelled by the deferred cn above, so it cannot leak.
	if s.lctx != nil {
		go func() {
			select {
			case <-s.lctx.Done():
				cn()
			case <-ctx.Done():
			}
		}()
	}

	var (
		mu      sync.Mutex // guards allErrs
		allErrs []error
		wg      sync.WaitGroup
	)
	readErr := func() error {
		for {
			// read messages from the stream synchronously
			incoming, batch, err := remote.ReadBatch(ctx)
			if err != nil {
				return err
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				responder := &callResponder{
					remote: remote,
					batch:  batch,
					mu:     sema,
				}
				// goroutine-local error; do not reuse the read loop's err
				if serveErr := s.serve(ctx, incoming, responder); serveErr != nil {
					mu.Lock()
					allErrs = append(allErrs, serveErr)
					mu.Unlock()
				}
			}()
		}
	}()
	wg.Wait()
	// errors.Join drops nil errors and returns nil only if all of them are nil.
	return errors.Join(append(allErrs, readErr)...)
}
// Shutdown cancels the server's lifetime context. The ctx argument is
// currently unused. On a Server that was not created with NewServer (and so
// has no cancel function) Shutdown does nothing instead of panicking.
func (s *Server) Shutdown(ctx context.Context) {
	if s.cn != nil {
		s.cn()
	}
}
// serve dispatches one unit read from the codec. A batch goes to serveBatch;
// anything else goes to serveSingle with its first message.
//
// If a non-batch read yielded no messages, a nil message is passed on.
// produceOutputMessage turns a nil message into an invalid request with a
// null id, so this case no longer panics on incoming[0].
func (s *Server) serve(ctx context.Context,
	incoming []*jsonrpc.Message,
	r *callResponder,
) error {
	if r.batch {
		return s.serveBatch(ctx, incoming, r)
	}
	var msg *jsonrpc.Message
	if len(incoming) > 0 {
		msg = incoming[0]
	}
	return s.serveSingle(ctx, msg, r)
}
// serveSingle serves a single, non-batched message. The handler is invoked
// synchronously with a streamingRespWriter. If the message carries an id and
// the handler never called Send, a null result is sent on its behalf.
//
// For notifications (no id), the codec is flushed before the handler runs.
// The write semaphore r.mu is held only for that flush, not while the handler
// runs. Holding it across ServeRPC would block every other response on this
// codec and any write that needs the semaphore while the handler executes.
func (s *Server) serveSingle(ctx context.Context,
	incoming *jsonrpc.Message,
	r *callResponder,
) error {
	rw := &streamingRespWriter{
		ctx: ctx,
		cr:  r,
	}
	rw.msg, rw.err = produceOutputMessage(incoming)
	req := jsonrpc.NewRequestFromMessage(
		ctx,
		rw.msg,
	)
	req.Peer = r.remote.PeerInfo()
	if rw.msg.ID == nil {
		// notification, so immediately flush, releasing the lock right away
		if err := r.mu.Acquire(ctx, 1); err != nil {
			return err
		}
		err := r.remote.Flush()
		r.mu.Release(1)
		if err != nil {
			return err
		}
	}
	s.services.ServeRPC(rw, req)
	if !rw.sendCalled && rw.msg.ID != nil {
		rw.Send(jsonrpc.Null, nil)
	}
	return nil
}
// produceOutputMessage normalizes an incoming message into the message used
// to build both the request and its response.
//
// The returned message is the input itself, mutated in place so that its
// Error field is cleared. If the input is nil, a fresh message with a null id
// is returned instead.
//
// A message with an empty method is reported as an invalid request through
// the returned error. If such a message also has no id, it is given a null id
// so that it is answered (and counted) as a request rather than treated as a
// notification.
func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message, err error) {
	msg := inputMessage
	if msg == nil {
		msg = &jsonrpc.Message{ID: jsonrpc.NewNullIDPtr()}
	}
	msg.Error = nil

	if len(msg.Method) > 0 {
		return msg, nil
	}

	// zero length method is always an invalid request
	if msg.ID == nil {
		msg.ID = jsonrpc.NewNullIDPtr()
	}
	return msg, jsonrpc.NewInvalidRequestError("invalid request")
}
// serveBatch serves a batch of messages.
//
// Every message is handled on its own goroutine with a batchingRespWriter.
// Once every message that carries an id has produced its response, the
// responses are written as one JSON array, in request order, and the codec is
// flushed. A batch made only of notifications just flushes the codec. An empty
// batch is answered with a single invalid-request error.
//
// serveBatch returns after every handler goroutine has returned, unless
// writing to the codec fails first.
//
// The write semaphore r.mu is held only while writing and flushing, never
// while waiting for handlers. This keeps it from blocking other responses on
// the same codec or writes made while handlers run.
func (s *Server) serveBatch(ctx context.Context,
	incoming []*jsonrpc.Message,
	r *callResponder,
) error {
	// withWriteLock runs f while holding the codec's write semaphore and
	// releases it as soon as f returns.
	withWriteLock := func(f func() error) error {
		if err := r.mu.Acquire(ctx, 1); err != nil {
			return err
		}
		defer r.mu.Release(1)
		return f()
	}

	// check for empty batch
	if r.batch && len(incoming) == 0 {
		// if it is empty batch, send the empty batch error and immediately return
		return withWriteLock(func() error {
			if err := r.send(ctx, &callEnv{
				id:  jsonrpc.NewNullIDPtr(),
				err: jsonrpc.NewInvalidRequestError("empty batch"),
			}); err != nil {
				return err
			}
			return r.remote.Flush()
		})
	}

	// synchronous pre-processing: build one response writer per message
	rs := make([]*batchingRespWriter, 0, len(incoming))
	totalRequests := 0
	for _, msg := range incoming {
		rw := &batchingRespWriter{
			ctx: ctx,
			cr:  r,
		}
		rw.msg, rw.err = produceOutputMessage(msg)
		// requests and malformed requests both count as requests
		if rw.msg.ID != nil {
			totalRequests++
		}
		rs = append(rs, rw)
	}

	// returnWg completes when every handler goroutine has returned.
	var returnWg sync.WaitGroup
	returnWg.Add(len(rs))
	// respWg completes when every request (message with an id) has a response.
	respWg := &sync.WaitGroup{}
	respWg.Add(totalRequests)
	peerInfo := r.remote.PeerInfo()
	batchResults := make([]*batchingRespWriter, 0, totalRequests)
	for _, v := range rs {
		v := v // per-iteration copy for the goroutine (pre-Go 1.22 loop semantics)
		if v.msg.ID != nil {
			v.wg = respWg
			batchResults = append(batchResults, v)
		}
		// now process each request in its own goroutine
		// TODO: stress test this.
		go func() {
			defer returnWg.Done()
			req := jsonrpc.NewRequestFromMessage(ctx, v.msg)
			req.Peer = peerInfo
			s.services.ServeRPC(v, req)
			if !v.sendCalled && v.err == nil {
				v.Send(jsonrpc.Null, nil)
			}
		}()
	}

	if totalRequests > 0 {
		// TODO: channel?
		respWg.Wait()
		err := withWriteLock(func() error {
			// write the responses as a JSON array, one by one, in request order
			if _, err := r.remote.Write([]byte{'['}); err != nil {
				return err
			}
			for i, v := range batchResults {
				if err := r.send(ctx, &callEnv{
					v:   v.payload,
					err: v.err,
					id:  v.msg.ID,
				}); err != nil {
					return err
				}
				// write the comma or ]
				sep := byte(',')
				if i == len(batchResults)-1 {
					sep = ']'
				}
				if _, err := r.remote.Write([]byte{sep}); err != nil {
					return err
				}
			}
			return r.remote.Flush()
		})
		if err != nil {
			return err
		}
	} else {
		// all notifications, so immediately flush
		if err := withWriteLock(r.remote.Flush); err != nil {
			return err
		}
	}
	returnWg.Wait()
	return nil
}
// callResponder writes responses and notifications for messages read from a
// single codec.
type callResponder struct {
	// remote is the codec that responses are written to.
	remote jsonrpc.ReaderWriter
	// mu is a weight-1 semaphore shared by every responder created for the
	// same codec in ServeCodec. Callers acquire it before writing to remote;
	// send and notify do not acquire it themselves.
	mu *semaphore.Weighted
	// batch reports whether the messages being answered arrived as a batch.
	batch bool
	// NOTE(review): batchStarted is not read or written in this file; confirm
	// whether other files in the package use it.
	batchStarted bool
}
// callEnv holds the content of a single JSON-RPC response written by
// callResponder.send.
type callEnv struct {
	// v is the result value; send treats json.RawMessage and
	// func(*jx.Encoder) error specially.
	v any
	// err, when non-nil, is written as the "error" member instead of a result.
	err error
	// id is echoed back as the "id" member; a nil id omits the member.
	id *jsonrpc.ID
}
// send writes one JSON-RPC 2.0 response object for env to c.remote.
//
// The object always carries "jsonrpc":"2.0". It carries "id" when env.id is
// non-nil. It then carries exactly one of "error" (when env.err is set or the
// result could not be encoded) or "result". The result is encoded as follows:
//   - nil: null
//   - json.RawMessage: written verbatim, or null when empty
//   - func(*jx.Encoder) error: called to encode the value itself
//   - anything else: encoded with go-json, without a trailing newline
//
// The result is encoded into a scratch buffer before the response is written.
// An encoding failure therefore produces a well-formed error response, instead
// of a half-written "result" member followed by an "error" member (invalid
// JSON).
//
// send does not acquire c.mu; callers must serialize writes to c.remote.
// ctx is currently unused.
func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
	respErr := env.err
	// result buffers the encoded result; a zero jx.Encoder writes to memory.
	var result jx.Encoder
	if respErr == nil {
		switch cast := env.v.(type) {
		case nil:
			result.Null()
		case json.RawMessage:
			if len(cast) == 0 {
				result.Null()
			} else {
				result.Raw(cast)
			}
		case func(e *jx.Encoder) error:
			respErr = cast(&result)
		default:
			respErr = json.NewEncoder(&result).EncodeWithOption(cast, func(eo *json.EncodeOption) {
				eo.DisableNewline = true
			})
		}
	}

	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
	enc.Obj(func(e *jx.Encoder) {
		e.Field("jsonrpc", func(e *jx.Encoder) {
			e.Str("2.0")
		})
		if env.id != nil {
			e.Field("id", func(e *jx.Encoder) {
				e.Raw(env.id.RawMessage())
			})
		}
		if respErr != nil {
			e.Field("error", func(e *jx.Encoder) {
				jsonrpc.EncodeError(e, respErr)
			})
			return
		}
		e.Field("result", func(e *jx.Encoder) {
			e.Raw(result.Bytes())
		})
	})
	// a json encoding error here is possibly fatal....
	return enc.Close()
}
// notifyEnv holds the content of an outgoing notification written by
// callResponder.notify.
type notifyEnv struct {
	// method becomes the notification's method name.
	method string
	// dat is JSON-encoded into the notification's params.
	dat any
}
// notify writes a JSON-RPC notification for env to c.remote. env.dat is
// JSON-encoded into the message params, and env.method becomes its method.
// If env.dat cannot be encoded, the encoding error is stored on the message's
// Error field instead of params.
//
// notify does not acquire c.mu; callers must serialize writes to c.remote.
// ctx is currently unused.
func (c *callResponder) notify(ctx context.Context, env *notifyEnv) (err error) {
	msg := &jsonrpc.Message{Method: env.method}
	// Temp buffer for this packet. msg.Params aliases it, so it is only
	// returned to the pool (deferred) after the message has been written.
	buf := bufpool.GetStd()
	defer bufpool.PutStd(buf)
	// Encode without the trailing newline json.Encoder appends by default,
	// the same way send encodes results, so params never embed a stray '\n'.
	encErr := json.NewEncoder(buf).EncodeWithOption(env.dat, func(eo *json.EncodeOption) {
		eo.DisableNewline = true
	})
	if encErr != nil {
		msg.Error = encErr
	} else {
		msg.Params = buf.Bytes()
	}

	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
	if err = jsonrpc.MarshalMessage(msg, enc); err != nil {
		return err
	}
	return enc.Close()
}