good morning!!!!

Skip to content
Snippets Groups Projects
server.go 8.05 KiB
Newer Older
a's avatar
a committed
package server
a's avatar
rpc
a committed

import (
	"context"
a's avatar
ok  
a committed
	"errors"
a's avatar
wg  
a committed
	"sync"
a's avatar
rpc
a committed

a's avatar
ok  
a committed
	"golang.org/x/sync/semaphore"
a's avatar
a committed

Garet Halliday's avatar
Garet Halliday committed
	"gfx.cafe/open/jrpc/pkg/jsonrpc"

a's avatar
a committed
	"gfx.cafe/util/go/bufpool"

	"github.com/go-faster/jx"
	"github.com/goccy/go-json"
a's avatar
rpc
a committed
)

// Server is an RPC server.
a's avatar
a  
a committed
// it is in charge of calling the handler on the message object, the json encoding of responses, and dealing with batch semantics.
// a server can be used to listenandserve multiple codecs at a time
a's avatar
rpc
a committed
type Server struct {
a's avatar
a committed
	services jsonrpc.Handler
a's avatar
ok  
a committed

	lctx context.Context
	cn   context.CancelFunc
a's avatar
rpc
a committed
}

// NewServer creates a new server instance with no registered handlers.
a's avatar
a committed
func NewServer(r jsonrpc.Handler) *Server {
	server := &Server{services: r}
a's avatar
ok  
a committed
	server.lctx, server.cn = context.WithCancel(context.Background())
a's avatar
rpc
a committed
	return server
}

// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes
// the response back using the given codec. It will block until the codec is closed
a's avatar
a committed
func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error {
	defer remote.Close()
a's avatar
ok  
a committed

	sema := semaphore.NewWeighted(1)
	// add a cancel to the context so we can cancel all the child tasks on return
	ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
	defer cn()

a's avatar
ok  
a committed
	allErrs := []error{}
	var mu sync.Mutex
	wg := sync.WaitGroup{}
	err := func() error {
		for {
			// read messages from the stream synchronously
			incoming, batch, err := remote.ReadBatch(ctx)
			if err != nil {
a's avatar
ok  
a committed
				return err
a's avatar
ok  
a committed
			wg.Add(1)
			go func() {
a's avatar
ok  
a committed
				defer wg.Done()
				responder := &callResponder{
					remote: remote,
					batch:  batch,
					mu:     sema,
				}
a's avatar
a committed
				err = s.serve(ctx, incoming, responder)
				if err != nil {
a's avatar
ok  
a committed
					mu.Lock()
					defer mu.Unlock()
					allErrs = append(allErrs, err)
a's avatar
a committed
	wg.Wait()
a's avatar
ok  
a committed
	allErrs = append(allErrs, err)
	if len(allErrs) > 0 {
		return errors.Join(allErrs...)
a's avatar
ok  
a committed
	return nil
}

func (s *Server) Shutdown(ctx context.Context) {
	s.cn()
a's avatar
a committed
func (s *Server) serve(ctx context.Context,
a's avatar
a committed
	incoming []*jsonrpc.Message,
a's avatar
a committed
	r *callResponder,
) error {
	if r.batch {
		return s.serveBatch(ctx, incoming, r)
	} else {
		return s.serveSingle(ctx, incoming[0], r)
	}
}

func (s *Server) serveSingle(ctx context.Context,
a's avatar
a committed
	incoming *jsonrpc.Message,
a's avatar
a committed
	r *callResponder,
) error {
	rw := &streamingRespWriter{
		ctx: ctx,
		cr:  r,
	}
	rw.msg, rw.err = produceOutputMessage(incoming)
a's avatar
a committed
	req := jsonrpc.NewRequestFromMessage(
a's avatar
a committed
		ctx,
		rw.msg,
	)
	req.Peer = r.remote.PeerInfo()
	if rw.msg.ID == nil {
		// all notification, so immediately flush
		err := r.mu.Acquire(ctx, 1)
		if err != nil {
			return err
		}
		defer r.mu.Release(1)
		err = r.remote.Flush()
		if err != nil {
			return err
		}
	}
	s.services.ServeRPC(rw, req)
	if rw.sendCalled == false && rw.msg.ID != nil {
a's avatar
a committed
		rw.Send(jsonrpc.Null, nil)
a's avatar
a committed
	}
	return nil
}

a's avatar
a committed
func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message, err error) {
a's avatar
a committed
	// a nil incoming message means return an invalid request.
	if inputMessage == nil {
a's avatar
a committed
		inputMessage = &jsonrpc.Message{ID: jsonrpc.NewNullIDPtr()}
		err = jsonrpc.NewInvalidRequestError("invalid request")
a's avatar
a committed
	}
	out = inputMessage
	out.Error = nil
	// zero length method is always invalid request
	if len(out.Method) == 0 {
		// assume if the method is not there AND the id is not there that it's an invalid REQUEST not notification
		// this makes sure we add 1 to totalRequests
		if out.ID == nil {
a's avatar
a committed
			out.ID = jsonrpc.NewNullIDPtr()
a's avatar
a committed
		}
a's avatar
a committed
		err = jsonrpc.NewInvalidRequestError("invalid request")
a's avatar
a committed
	}

	return
}

a's avatar
a committed
func (s *Server) serveBatch(ctx context.Context,
a's avatar
a committed
	incoming []*jsonrpc.Message,
a's avatar
ok  
a committed
	r *callResponder,
) error {
a's avatar
a committed
	// check for empty batch
a's avatar
ok  
a committed
	if r.batch && len(incoming) == 0 {
a's avatar
a  
a committed
		// if it is empty batch, send the empty batch error and immediately return
a's avatar
a committed
		err := r.mu.Acquire(ctx, 1)
		if err != nil {
			return err
		}
		defer r.mu.Release(1)
		err = r.send(ctx, &callEnv{
a's avatar
a committed
			id:  jsonrpc.NewNullIDPtr(),
			err: jsonrpc.NewInvalidRequestError("empty batch"),
a's avatar
a  
a committed
		})
a's avatar
ok  
a committed
		if err != nil {
			return err
		}
a's avatar
a committed
		err = r.remote.Flush()
		if err != nil {
			return err
		}
		return nil
a's avatar
a committed
	}

a's avatar
a committed
	rs := []*batchingRespWriter{}
a's avatar
ok  
a committed

	totalRequests := 0
a's avatar
a  
a committed
	// populate the envelope we are about to send. this is synchronous pre-prpcessing
a's avatar
a committed
	for _, v := range incoming {
a's avatar
a  
a committed
		// create the response writer
a's avatar
a committed
		rw := &batchingRespWriter{
a's avatar
ok  
a committed
			ctx: ctx,
			cr:  r,
a's avatar
a committed
		}
a's avatar
ok  
a committed
		rs = append(rs, rw)
a's avatar
a committed
		rw.msg, rw.err = produceOutputMessage(v)
a's avatar
a committed
		// requests and malformed requests both count as requests
a's avatar
a committed
		if rw.msg.ID != nil {
a's avatar
ok  
a committed
			totalRequests += 1
a's avatar
a  
a committed
		}
a's avatar
ok  
a committed
	}
a's avatar
a committed

a's avatar
a committed
	// create a waitgroup for when every handler returns
	returnWg := sync.WaitGroup{}
	returnWg.Add(len(rs))
a's avatar
a  
a committed
	// for each item in the envelope
a's avatar
ok  
a committed
	peerInfo := r.remote.PeerInfo()
a's avatar
a committed
	batchResults := []*batchingRespWriter{}

	respWg := &sync.WaitGroup{}
	respWg.Add(totalRequests)

a's avatar
ok  
a committed
	for _, vRef := range rs {
a's avatar
a committed
		v := vRef
a's avatar
a committed
		if v.msg.ID != nil {
			v.wg = respWg
			batchResults = append(batchResults, v)
a's avatar
ok  
a committed
		}
a's avatar
ok  
a committed
		// now process each request in its own goroutine
		// TODO: stress test this.
a's avatar
a committed
		go func() {
a's avatar
a committed
			defer returnWg.Done()
a's avatar
a committed
			req := jsonrpc.NewRequestFromMessage(
a's avatar
a committed
				ctx,
a's avatar
a committed
				v.msg,
			)
a's avatar
ok  
a committed
			req.Peer = peerInfo
			s.services.ServeRPC(v, req)
a's avatar
a committed
			if v.sendCalled == false && v.err == nil {
a's avatar
a committed
				v.Send(jsonrpc.Null, nil)
a's avatar
a committed
			}
a's avatar
a committed
		}()
	}
a's avatar
a committed

a's avatar
a committed
	if totalRequests > 0 {
		// TODO: channel?
		respWg.Wait()
		err := r.mu.Acquire(ctx, 1)
a's avatar
jx  
a committed
		if err != nil {
a's avatar
ok  
a committed
			return err
a's avatar
jx  
a committed
		}
a's avatar
ok  
a committed
		defer r.mu.Release(1)
		// write them, one by one
		_, err = r.remote.Write([]byte{'['})
		if err != nil {
			return err
		}
		for i, v := range batchResults {
			err = r.send(ctx, &callEnv{
a's avatar
a committed
				v:   v.payload,
				err: v.err,
				id:  v.msg.ID,
a's avatar
ok  
a committed
			})
			if err != nil {
				return err
			}
			// write the comma or ]
			char := ','
			if i == len(batchResults)-1 {
				char = ']'
			}
			_, err = r.remote.Write([]byte{byte(char)})
			if err != nil {
				return err
			}
		}
		err = r.remote.Flush()
a's avatar
jx  
a committed
		if err != nil {
			return err
		}
a's avatar
ok  
a committed
	} else if totalRequests == 0 {
a's avatar
a committed
		// all notification, so immediately flush
		err := r.mu.Acquire(ctx, 1)
a's avatar
ok  
a committed
		if err != nil {
			return err
		}
		defer r.mu.Release(1)
a's avatar
a committed
		err = r.remote.Flush()
a's avatar
ok  
a committed
		if err != nil {
			return err
		}
a's avatar
a committed
	}
a's avatar
a committed
	returnWg.Wait()
a's avatar
a committed
	return nil
a's avatar
a committed
}

a's avatar
ok  
a committed
type callResponder struct {
a's avatar
a committed
	remote jsonrpc.ReaderWriter
a's avatar
ok  
a committed
	mu     *semaphore.Weighted

	batch        bool
	batchStarted bool
}

a's avatar
a committed
type callEnv struct {
a's avatar
a committed
	v   any
	err error
	id  *jsonrpc.ID
a's avatar
a committed
}

a's avatar
a committed
func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
a's avatar
ok  
a committed
	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
	enc.Obj(func(e *jx.Encoder) {
		e.Field("jsonrpc", func(e *jx.Encoder) {
			e.Str("2.0")
		})
		if env.id != nil {
			e.Field("id", func(e *jx.Encoder) {
				e.Raw(env.id.RawMessage())
			})
a's avatar
a committed
		}
a's avatar
ok  
a committed
		if env.err != nil {
			e.Field("error", func(e *jx.Encoder) {
a's avatar
a committed
				jsonrpc.EncodeError(e, env.err)
a's avatar
ok  
a committed
			})
		} else {
			// if there is no error, we try to marshal the result
			e.Field("result", func(e *jx.Encoder) {
				if env.v != nil {
a's avatar
a committed
					switch cast := (env.v).(type) {
a's avatar
ok  
a committed
					case json.RawMessage:
a's avatar
a committed
						if len(cast) == 0 {
							e.Null()
						} else {
							e.Raw(cast)
						}
a's avatar
a committed
					case func(e *jx.Encoder) error:
						err = cast(e)
a's avatar
ok  
a committed
					default:
						err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) {
							eo.DisableNewline = true
						})
					}
				} else {
					e.Null()
a's avatar
jx  
a committed
				}
a's avatar
a  
a committed
			})
a's avatar
a committed
		}
a's avatar
a committed
		if env.err == nil && err != nil {
			e.Field("error", func(e *jx.Encoder) {
				jsonrpc.EncodeError(e, err)
			})
		}
a's avatar
jx  
a committed
	})
a's avatar
ok  
a committed
	// a json encoding error here is possibly fatal....
	err = enc.Close()
a's avatar
a committed
	if err != nil {
		return err
	}
	return nil
}
a's avatar
ok  
a committed

// notifyEnv describes one server-initiated notification for
// callResponder.notify.
type notifyEnv struct {
	method string // method name placed in the outgoing message
	dat    any    // value JSON-encoded into the message params
}

func (c *callResponder) notify(ctx context.Context, env *notifyEnv) (err error) {
a's avatar
a committed
	msg := &jsonrpc.Message{}
a's avatar
ok  
a committed
	//  allocate a temp buffer for this packet
	buf := bufpool.GetStd()
	defer bufpool.PutStd(buf)
	err = json.NewEncoder(buf).Encode(env.dat)
	if err != nil {
		msg.Error = err
	} else {
		msg.Params = buf.Bytes()
	}
	// add the method
	msg.Method = env.method
	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
a's avatar
a committed
	err = jsonrpc.MarshalMessage(msg, enc)
a's avatar
ok  
a committed
	if err != nil {
		return err
	}
	return enc.Close()
}