good morning!!!!

Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
server.go 8.05 KiB
package server

import (
	"context"
	"errors"
	"sync"

	"golang.org/x/sync/semaphore"

	"gfx.cafe/open/jrpc/pkg/jsonrpc"

	"gfx.cafe/util/go/bufpool"

	"github.com/go-faster/jx"
	"github.com/goccy/go-json"
)

// Server is a JSON-RPC server.
// It calls the handler for each incoming message, JSON-encodes the responses,
// and applies JSON-RPC batch semantics. A single Server can serve several
// codecs at the same time (one ServeCodec call per codec).
type Server struct {
	// services is invoked (via ServeRPC) for every message read from a codec.
	services jsonrpc.Handler

	// lctx is the server-wide context created by NewServer; cn is its cancel
	// function, which Shutdown calls.
	lctx context.Context
	cn   context.CancelFunc
}

// NewServer returns a Server that dispatches every request to the handler r.
// The returned server owns a cancellable background context that Shutdown
// cancels.
func NewServer(r jsonrpc.Handler) *Server {
	lifetime, cancel := context.WithCancel(context.Background())
	return &Server{
		services: r,
		lctx:     lifetime,
		cn:       cancel,
	}
}

// ServeCodec reads messages from remote, dispatches each unit it reads (a
// single message or a batch) to the handler in its own goroutine, and writes
// the responses back to remote. It blocks until reading from remote fails,
// then waits for all in-flight dispatches to finish and closes remote.
//
// The context handed to handlers is derived from ctx (with the remote's peer
// info attached). It is cancelled when ServeCodec returns, and also when the
// server is shut down via Shutdown.
//
// The returned error joins the read error that ended the loop with every
// error returned while serving individual units; it is nil only if all of
// them are nil.
func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error {
	defer remote.Close()

	// A weight-1 semaphore shared by every responder of this codec; holders
	// get exclusive use of remote, so concurrent responses never interleave.
	sema := semaphore.NewWeighted(1)
	// add a cancel to the context so we can cancel all the child tasks on return
	ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
	defer cn()

	// Shutdown cancels s.lctx. Propagate that to this codec so reads and
	// handlers observe the shutdown. The watcher exits once ctx is done, which
	// happens at the latest when ServeCodec returns (deferred cn above).
	if s.lctx != nil {
		go func() {
			select {
			case <-s.lctx.Done():
				cn()
			case <-ctx.Done():
			}
		}()
	}

	var (
		mu   sync.Mutex // guards errs while dispatch goroutines are running
		errs []error
		wg   sync.WaitGroup
	)
	readErr := func() error {
		for {
			// read messages from the stream synchronously
			incoming, batch, err := remote.ReadBatch(ctx)
			if err != nil {
				return err
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				responder := &callResponder{
					remote: remote,
					batch:  batch,
					mu:     sema,
				}
				if err := s.serve(ctx, incoming, responder); err != nil {
					mu.Lock()
					errs = append(errs, err)
					mu.Unlock()
				}
			}()
		}
	}()
	wg.Wait()
	// every dispatch goroutine has returned, so errs is no longer shared
	errs = append(errs, readErr)
	return errors.Join(errs...)
}

// Shutdown cancels the server-wide context created by NewServer.
// The ctx argument is currently not used.
func (s *Server) Shutdown(ctx context.Context) {
	cancel := s.cn
	cancel()
}

// serve dispatches one unit read from the codec. A batch goes to serveBatch;
// anything else goes to serveSingle with its first message.
//
// A non-batch read that yielded no messages is handed to serveSingle as a nil
// message, which produceOutputMessage turns into an "invalid request" reply.
// The alternative, indexing incoming[0], would panic inside the dispatch
// goroutine and take the whole process down.
func (s *Server) serve(ctx context.Context,
	incoming []*jsonrpc.Message,
	r *callResponder,
) error {
	if r.batch {
		return s.serveBatch(ctx, incoming, r)
	}
	if len(incoming) == 0 {
		return s.serveSingle(ctx, nil, r)
	}
	return s.serveSingle(ctx, incoming[0], r)
}

// serveSingle runs the handler for one non-batch message, using a
// streamingRespWriter to produce its response.
//
// For a notification (the message has no id) the connection is flushed before
// the handler runs, and no reply is synthesized afterwards. For a request
// whose handler returns without calling Send, a null result is sent on its
// behalf.
func (s *Server) serveSingle(ctx context.Context,
	incoming *jsonrpc.Message,
	r *callResponder,
) error {
	rw := &streamingRespWriter{
		ctx: ctx,
		cr:  r,
	}
	rw.msg, rw.err = produceOutputMessage(incoming)
	req := jsonrpc.NewRequestFromMessage(ctx, rw.msg)
	req.Peer = r.remote.PeerInfo()

	if rw.msg.ID == nil {
		// Notification: flush immediately. The write lock is released before
		// the handler runs. Holding it for the handler's whole lifetime would
		// stall every other writer on this connection, and would deadlock a
		// handler that itself needs the lock (presumably any write through
		// rw — confirm against streamingRespWriter).
		flush := func() error {
			if err := r.mu.Acquire(ctx, 1); err != nil {
				return err
			}
			defer r.mu.Release(1)
			return r.remote.Flush()
		}
		if err := flush(); err != nil {
			return err
		}
	}

	s.services.ServeRPC(rw, req)
	if !rw.sendCalled && rw.msg.ID != nil {
		rw.Send(jsonrpc.Null, nil)
	}
	return nil
}

// produceOutputMessage normalizes an incoming message into the message that
// the response is built from.
//
// The input is reused rather than copied, and its Error is cleared in place.
// A nil input is replaced by a fresh message with a null id. Any message
// without a method is an invalid request: it is given a null id if it had
// none (so it is counted and answered as a request, not as a notification),
// and an invalid-request error is returned alongside it.
func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message, err error) {
	if inputMessage != nil {
		out = inputMessage
	} else {
		out = &jsonrpc.Message{ID: jsonrpc.NewNullIDPtr()}
	}
	out.Error = nil

	if len(out.Method) != 0 {
		return out, nil
	}
	if out.ID == nil {
		out.ID = jsonrpc.NewNullIDPtr()
	}
	return out, jsonrpc.NewInvalidRequestError("invalid request")
}

// serveBatch answers a JSON-RPC batch.
//
// An empty batch is answered with a single "empty batch" invalid-request
// error object, not an array. Otherwise every message is dispatched to the
// handler in its own goroutine. Once every message carrying an id has
// produced its response, the responses are written as one JSON array, in the
// order the messages arrived. A batch made only of notifications writes
// nothing and just flushes the connection.
//
// The connection's write semaphore is held only while writing. It is released
// before waiting for the batch's handlers to return, so long-running handlers
// do not block other writers on the same connection.
func (s *Server) serveBatch(ctx context.Context,
	incoming []*jsonrpc.Message,
	r *callResponder,
) error {
	// withWriteLock runs fn while holding the connection's write semaphore.
	withWriteLock := func(fn func() error) error {
		if err := r.mu.Acquire(ctx, 1); err != nil {
			return err
		}
		defer r.mu.Release(1)
		return fn()
	}

	if r.batch && len(incoming) == 0 {
		return withWriteLock(func() error {
			err := r.send(ctx, &callEnv{
				id:  jsonrpc.NewNullIDPtr(),
				err: jsonrpc.NewInvalidRequestError("empty batch"),
			})
			if err != nil {
				return err
			}
			return r.remote.Flush()
		})
	}

	// Synchronous pre-processing: one response writer per message, counting
	// the messages that expect a reply (requests and malformed requests both
	// carry an id at this point).
	writers := make([]*batchingRespWriter, 0, len(incoming))
	totalRequests := 0
	for _, msg := range incoming {
		rw := &batchingRespWriter{
			ctx: ctx,
			cr:  r,
		}
		rw.msg, rw.err = produceOutputMessage(msg)
		if rw.msg.ID != nil {
			totalRequests++
		}
		writers = append(writers, rw)
	}

	peerInfo := r.remote.PeerInfo()
	// respWg is handed to every reply-bearing writer; presumably the writer
	// calls Done once its response is ready (confirm in batchingRespWriter).
	// returnWg counts down as each handler goroutine returns.
	respWg := &sync.WaitGroup{}
	respWg.Add(totalRequests)
	returnWg := sync.WaitGroup{}
	returnWg.Add(len(writers))
	replies := make([]*batchingRespWriter, 0, totalRequests)

	for _, rw := range writers {
		rw := rw
		if rw.msg.ID != nil {
			rw.wg = respWg
			replies = append(replies, rw)
		}
		// now process each request in its own goroutine
		// TODO: stress test this.
		go func() {
			defer returnWg.Done()
			req := jsonrpc.NewRequestFromMessage(ctx, rw.msg)
			req.Peer = peerInfo
			s.services.ServeRPC(rw, req)
			if !rw.sendCalled && rw.err == nil {
				rw.Send(jsonrpc.Null, nil)
			}
		}()
	}

	var err error
	if totalRequests > 0 {
		// TODO: channel?
		respWg.Wait()
		err = withWriteLock(func() error {
			if _, err := r.remote.Write([]byte{'['}); err != nil {
				return err
			}
			for i, rw := range replies {
				sendErr := r.send(ctx, &callEnv{
					v:   rw.payload,
					err: rw.err,
					id:  rw.msg.ID,
				})
				if sendErr != nil {
					return sendErr
				}
				// separate elements with ',' and close the array after the last
				sep := byte(',')
				if i == len(replies)-1 {
					sep = ']'
				}
				if _, err := r.remote.Write([]byte{sep}); err != nil {
					return err
				}
			}
			return r.remote.Flush()
		})
	} else {
		// all notifications: nothing to write, so just flush
		err = withWriteLock(func() error {
			return r.remote.Flush()
		})
	}
	if err != nil {
		return err
	}

	// wait for every handler to return, without holding the write lock
	returnWg.Wait()
	return nil
}

// callResponder writes the response(s) for one unit read from a codec (a
// single message or a batch).
//
// Fields:
//   - remote: the connection that responses are written to and flushed on.
//   - mu: a weight-1 semaphore shared by all responders created by the same
//     ServeCodec call; holding it grants exclusive use of remote.
//   - batch: whether the unit being answered was a JSON-RPC batch.
//   - batchStarted: not read by any code in this file section — TODO confirm
//     whether it is still used.
type callResponder struct {
	remote       jsonrpc.ReaderWriter
	mu           *semaphore.Weighted
	batch        bool
	batchStarted bool
}

// callEnv carries what send needs to write one response object:
// v is the result value (encoded only when err is nil), err is written as
// the "error" member when non-nil, and id is echoed as the "id" member
// (omitted when nil).
type callEnv struct {
	v   any
	err error
	id  *jsonrpc.ID
}

// send writes one JSON-RPC response object for env to c.remote, either
// {"jsonrpc":"2.0","id":…,"result":…} or, when env.err is set or the result
// cannot be encoded, {"jsonrpc":"2.0","id":…,"error":…}. The "id" member is
// omitted when env.id is nil. ctx is currently unused.
//
// The result is encoded into a scratch encoder before the object is started.
// Encoding straight into the response meant a failed encode left a dangling
// `"result":` key followed by an "error" member, which is invalid JSON; now
// exactly one of result/error is written.
func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
	respErr := env.err
	var result jx.Encoder
	if respErr == nil {
		respErr = encodeCallResult(&result, env.v)
	}

	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
	enc.Obj(func(e *jx.Encoder) {
		e.Field("jsonrpc", func(e *jx.Encoder) {
			e.Str("2.0")
		})
		if env.id != nil {
			e.Field("id", func(e *jx.Encoder) {
				e.Raw(env.id.RawMessage())
			})
		}
		if respErr != nil {
			e.Field("error", func(e *jx.Encoder) {
				jsonrpc.EncodeError(e, respErr)
			})
			return
		}
		e.Field("result", func(e *jx.Encoder) {
			// an encoder func that wrote nothing must still yield a value
			if raw := result.Bytes(); len(raw) > 0 {
				e.Raw(raw)
			} else {
				e.Null()
			}
		})
	})
	// a json encoding error here is possibly fatal....
	return enc.Close()
}

// encodeCallResult writes v into dst as a single JSON value. A nil v and an
// empty json.RawMessage become null; a non-empty json.RawMessage is copied
// verbatim; a func(*jx.Encoder) error is invoked on dst; anything else is
// encoded with go-json without a trailing newline. On error dst may hold
// partial output and must be discarded.
func encodeCallResult(dst *jx.Encoder, v any) error {
	switch cast := v.(type) {
	case nil:
		dst.Null()
		return nil
	case json.RawMessage:
		if len(cast) == 0 {
			dst.Null()
		} else {
			dst.Raw(cast)
		}
		return nil
	case func(e *jx.Encoder) error:
		return cast(dst)
	default:
		return json.NewEncoder(dst).EncodeWithOption(cast, func(eo *json.EncodeOption) {
			eo.DisableNewline = true
		})
	}
}

// notifyEnv carries the input for notify: the method name to write, and the
// value dat whose JSON encoding becomes the message params.
type notifyEnv struct {
	method string
	dat    any
}
// notify writes a message with method env.method and no id to c.remote. Its
// params are the JSON encoding of env.dat; if env.dat cannot be encoded, the
// message is written with Error set instead of Params. Nothing is flushed
// here, and ctx is currently unused.
func (c *callResponder) notify(ctx context.Context, env *notifyEnv) (err error) {
	// scratch holds the encoded params until the message has been written
	scratch := bufpool.GetStd()
	defer bufpool.PutStd(scratch)

	msg := &jsonrpc.Message{}
	if encErr := json.NewEncoder(scratch).Encode(env.dat); encErr == nil {
		msg.Params = scratch.Bytes()
	} else {
		msg.Error = encErr
	}
	msg.Method = env.method

	out := jx.GetEncoder()
	defer jx.PutEncoder(out)
	out.Grow(4096)
	out.ResetWriter(c.remote)
	if err = jsonrpc.MarshalMessage(msg, out); err != nil {
		return err
	}
	return out.Close()
}