good morning!!!!

Skip to content
Snippets Groups Projects
server.go 6.98 KiB
Newer Older
a's avatar
a committed
package server
a's avatar
rpc
a committed

import (
	"context"
a's avatar
ok  
a committed
	"errors"
a's avatar
wg  
a committed
	"sync"
a's avatar
rpc
a committed

a's avatar
a committed
	"gfx.cafe/open/jrpc/pkg/codec"
a's avatar
ok  
a committed
	"golang.org/x/sync/semaphore"
a's avatar
a committed

a's avatar
a committed
	"gfx.cafe/util/go/bufpool"

	"github.com/go-faster/jx"
	"github.com/goccy/go-json"
a's avatar
rpc
a committed
)

// Server is an RPC server.
a's avatar
a  
a committed
// it is in charge of calling the handler on the message object, the json encoding of responses, and dealing with batch semantics.
// a server can be used to listenandserve multiple codecs at a time
a's avatar
rpc
a committed
type Server struct {
a's avatar
a committed
	services codec.Handler
a's avatar
ok  
a committed

	lctx context.Context
	cn   context.CancelFunc
a's avatar
rpc
a committed
}

// NewServer creates a new server instance with no registered handlers.
a's avatar
a committed
func NewServer(r codec.Handler) *Server {
	server := &Server{services: r}
a's avatar
ok  
a committed
	server.lctx, server.cn = context.WithCancel(context.Background())
a's avatar
rpc
a committed
	return server
}

// ServeCodec reads incoming requests from the codec, calls the appropriate callback, and writes
// the response back using the given codec. It blocks until the codec is closed.
func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) error {
	defer remote.Close()
a's avatar
ok  
a committed

	sema := semaphore.NewWeighted(1)
	// add a cancel to the context so we can cancel all the child tasks on return
	ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
	defer cn()

a's avatar
ok  
a committed
	allErrs := []error{}
	var mu sync.Mutex
	wg := sync.WaitGroup{}
	err := func() error {
		for {
			// read messages from the stream synchronously
			incoming, batch, err := remote.ReadBatch(ctx)
			if err != nil {
a's avatar
ok  
a committed
				return err
a's avatar
ok  
a committed
			wg.Add(1)
			go func() {
a's avatar
ok  
a committed
				defer wg.Done()
				responder := &callResponder{
					remote: remote,
					batch:  batch,
					mu:     sema,
				}
				err = s.serveBatch(ctx, incoming, responder)
				if err != nil {
a's avatar
ok  
a committed
					mu.Lock()
					defer mu.Unlock()
					allErrs = append(allErrs, err)
a's avatar
ok  
a committed
	allErrs = append(allErrs, err)
	if len(allErrs) > 0 {
		return errors.Join(allErrs...)
a's avatar
ok  
a committed
	return nil
}

func (s *Server) Shutdown(ctx context.Context) {
	s.cn()
a's avatar
a committed
func (s *Server) serveBatch(ctx context.Context,
	incoming []*codec.Message,
a's avatar
ok  
a committed
	r *callResponder,
) error {
a's avatar
a committed
	// check for empty batch
a's avatar
ok  
a committed
	if r.batch && len(incoming) == 0 {
a's avatar
a  
a committed
		// if it is empty batch, send the empty batch error and immediately return
a's avatar
a committed
		err := r.mu.Acquire(ctx, 1)
		if err != nil {
			return err
		}
		defer r.mu.Release(1)
		err = r.send(ctx, &callEnv{
			id:  codec.NewNullIDPtr(),
			err: codec.NewInvalidRequestError("empty batch"),
a's avatar
a  
a committed
		})
a's avatar
ok  
a committed
		if err != nil {
			return err
		}
a's avatar
a committed
		err = r.remote.Flush()
		if err != nil {
			return err
		}
		return nil
a's avatar
a committed
	}

a's avatar
ok  
a committed
	rs := []*callRespWriter{}

	totalRequests := 0
a's avatar
a  
a committed
	// populate the envelope we are about to send. this is synchronous pre-prpcessing
a's avatar
a committed
	for _, v := range incoming {
a's avatar
a  
a committed
		// create the response writer
a's avatar
a committed
		rw := &callRespWriter{
a's avatar
ok  
a committed
			ctx: ctx,
			cr:  r,
a's avatar
a committed
		}
a's avatar
ok  
a committed
		rs = append(rs, rw)
a's avatar
a  
a committed
		// a nil incoming message means an empty response
		if v == nil {
a's avatar
ok  
a committed
			v = &codec.Message{ID: codec.NewNullIDPtr()}
a's avatar
a committed
			rw.err = codec.NewInvalidRequestError("invalid request")
a's avatar
a  
a committed
		}
		rw.msg = v
a's avatar
a committed
		rw.msg.ExtraFields = codec.ExtraFields{}
		rw.msg.Error = nil
a's avatar
ok  
a committed
		if len(v.Method) == 0 {
a's avatar
a committed
			if v.ID == nil {
				v.ID = codec.NewNullIDPtr()
			}
a's avatar
ok  
a committed
			rw.err = codec.NewInvalidRequestError("invalid request")
		}
a's avatar
ok  
a committed
		if v.ID != nil {
			totalRequests += 1
a's avatar
a  
a committed
		}
a's avatar
ok  
a committed
	}
	var doneMu *semaphore.Weighted
	doneMu = semaphore.NewWeighted(int64(totalRequests))
	err := doneMu.Acquire(ctx, int64(totalRequests))
	if err != nil {
		return err
a's avatar
a committed
	}

a's avatar
ok  
a committed
	// create a waitgroup for everything
a's avatar
a committed
	wg := sync.WaitGroup{}
a's avatar
ok  
a committed
	wg.Add(len(rs))
a's avatar
a  
a committed
	// for each item in the envelope
a's avatar
ok  
a committed
	peerInfo := r.remote.PeerInfo()
	batchResults := []*callRespWriter{}
	for _, vRef := range rs {
a's avatar
a committed
		v := vRef
a's avatar
ok  
a committed
		if r.batch {
a's avatar
ok  
a committed
			v.noStream = true
a's avatar
ok  
a committed
			if v.msg.ID != nil {
				v.doneMu = doneMu
				batchResults = append(batchResults, v)
			}
		}
a's avatar
ok  
a committed
		// now process each request in its own goroutine
		// TODO: stress test this.
a's avatar
a committed
		go func() {
			defer wg.Done()
a's avatar
ok  
a committed
			req := codec.NewRequestFromMessage(
a's avatar
a committed
				ctx,
a's avatar
a committed
				v.msg,
			)
a's avatar
ok  
a committed
			req.Peer = peerInfo
			s.services.ServeRPC(v, req)
a's avatar
a committed
		}()
	}
a's avatar
a committed
	if r.batch && totalRequests > 0 {
a's avatar
ok  
a committed
		err = doneMu.Acquire(ctx, int64(totalRequests))
		if err != nil {
			return err
		}
a's avatar
ok  
a committed
		err = r.mu.Acquire(ctx, 1)
a's avatar
jx  
a committed
		if err != nil {
a's avatar
ok  
a committed
			return err
a's avatar
jx  
a committed
		}
a's avatar
ok  
a committed
		defer r.mu.Release(1)
		// write them, one by one
		_, err = r.remote.Write([]byte{'['})
		if err != nil {
			return err
		}
		for i, v := range batchResults {
a's avatar
ok  
a committed
			var a any
			a = v.payload
a's avatar
ok  
a committed
			err = r.send(ctx, &callEnv{
a's avatar
ok  
a committed
				v:           &a,
a's avatar
ok  
a committed
				err:         v.err,
				id:          v.msg.ID,
				extrafields: v.msg.ExtraFields,
			})
			if err != nil {
				return err
			}
			// write the comma or ]
			char := ','
			if i == len(batchResults)-1 {
				char = ']'
			}
			_, err = r.remote.Write([]byte{byte(char)})
			if err != nil {
				return err
			}
		}
		err = r.remote.Flush()
a's avatar
jx  
a committed
		if err != nil {
			return err
		}
a's avatar
ok  
a committed
	} else if totalRequests == 0 {
		err = r.mu.Acquire(ctx, 1)
		if err != nil {
			return err
		}
		defer r.mu.Release(1)
		err := r.remote.Flush()
		if err != nil {
			return err
		}
a's avatar
a committed
	}
a's avatar
ok  
a committed
	wg.Wait()
a's avatar
a committed
	return nil
a's avatar
a committed
}

a's avatar
ok  
a committed
// callResponder writes responses and notifications for one unit read from the
// codec (a single message or a batch) back to the remote.
type callResponder struct {
	// remote is the codec the encoded messages are written to.
	remote codec.ReaderWriter
	// mu is the weight-1 semaphore created per ServeCodec call; serveBatch holds
	// it while writing to and flushing the remote.
	mu     *semaphore.Weighted

	// batch reports whether the unit read from the codec was a batch.
	batch        bool
	// batchStarted is not referenced in the visible part of this file.
	batchStarted bool
}

a's avatar
a committed
type callEnv struct {
a's avatar
ok  
a committed
	v           *any
a's avatar
ok  
a committed
	err         error
	id          *codec.ID
	extrafields codec.ExtraFields
a's avatar
a committed
}

a's avatar
a committed
func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
a's avatar
ok  
a committed
	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
	enc.Obj(func(e *jx.Encoder) {
		e.Field("jsonrpc", func(e *jx.Encoder) {
			e.Str("2.0")
		})
		if env.id != nil {
			e.Field("id", func(e *jx.Encoder) {
				e.Raw(env.id.RawMessage())
			})
a's avatar
a committed
		}
a's avatar
ok  
a committed
		if env.extrafields != nil {
			for k, v := range env.extrafields {
				e.Field(k, func(e *jx.Encoder) {
					e.Raw(v)
a's avatar
a  
a committed
				})
a's avatar
ok  
a committed
			}
		}
		if env.err != nil {
			e.Field("error", func(e *jx.Encoder) {
				codec.EncodeError(e, env.err)
			})
		} else {
			// if there is no error, we try to marshal the result
			e.Field("result", func(e *jx.Encoder) {
				if env.v != nil {
a's avatar
ok  
a committed
					switch cast := (*env.v).(type) {
a's avatar
ok  
a committed
					case json.RawMessage:
						e.Raw(cast)
					default:
						err = json.NewEncoder(e).EncodeWithOption(cast, func(eo *json.EncodeOption) {
							eo.DisableNewline = true
						})
						if err != nil {
							return
a's avatar
a  
a committed
						}
a's avatar
ok  
a committed
					}
				} else {
					e.Null()
a's avatar
jx  
a committed
				}
a's avatar
a  
a committed
			})
a's avatar
a committed
		}
a's avatar
jx  
a committed
	})
a's avatar
ok  
a committed
	// a json encoding error here is possibly fatal....
	if err != nil {
		return err
	}
	err = enc.Close()
a's avatar
a committed
	if err != nil {
		return err
	}
	return nil
}
a's avatar
ok  
a committed

// notifyEnv is the envelope for a single notification passed to
// callResponder.notify.
type notifyEnv struct {
	// method is copied into the outgoing message's Method.
	method string
	// dat is JSON-encoded and used as the outgoing message's Params.
	dat    any
	// extra is copied into the outgoing message's ExtraFields.
	extra  codec.ExtraFields
}

// notify builds a message from env and writes it to c.remote.
//
// env.dat is JSON-encoded into a pooled scratch buffer and used as the message
// params; if that encoding fails, the error is stored in the message's Error
// field instead and the message is still sent. The returned error is the one
// from marshalling the message or from closing the encoder. ctx is currently
// unused.
//
// NOTE(review): Encode presumably appends a trailing newline to the params
// (send disables it explicitly) — confirm whether that matters for the codec.
func (c *callResponder) notify(ctx context.Context, env *notifyEnv) (err error) {
	// scratch buffer for the encoded params; it goes back to the pool on return,
	// after the message has been marshalled
	scratch := bufpool.GetStd()
	defer bufpool.PutStd(scratch)

	out := &codec.Message{}
	if err = json.NewEncoder(scratch).Encode(env.dat); err != nil {
		out.Error = err
	} else {
		out.Params = scratch.Bytes()
	}
	out.ExtraFields = env.extra
	out.Method = env.method

	enc := jx.GetEncoder()
	defer jx.PutEncoder(enc)
	enc.Grow(4096)
	enc.ResetWriter(c.remote)
	if err = codec.MarshalMessage(out, enc); err != nil {
		return err
	}
	return enc.Close()
}