good morning!!!!

Skip to content
Snippets Groups Projects
server.go 5.73 KiB
Newer Older
a's avatar
a committed
package server
a's avatar
rpc
a committed

import (
a's avatar
a committed
	"bytes"
a's avatar
rpc
a committed
	"context"
a's avatar
wg  
a committed
	"sync"
a's avatar
rpc
a committed

a's avatar
a committed
	"gfx.cafe/open/jrpc/pkg/codec"

a's avatar
a committed
	"gfx.cafe/util/go/bufpool"

	"github.com/go-faster/jx"
	"github.com/goccy/go-json"
a's avatar
rpc
a committed
)

// Server is an RPC server.
a's avatar
a  
a committed
// it is in charge of calling the handler on the message object, the json encoding of responses, and dealing with batch semantics.
// a server can be used to listenandserve multiple codecs at a time
a's avatar
rpc
a committed
type Server struct {
a's avatar
a committed
	services codec.Handler
a's avatar
rpc
a committed
}

// NewServer creates a new server instance with no registered handlers.
a's avatar
a committed
func NewServer(r codec.Handler) *Server {
	server := &Server{services: r}
a's avatar
rpc
a committed
	return server
}

// ServeCodec reads requests from remote, serves each batch on its own
// goroutine, and writes the responses back through remote.
//
// It blocks until ctx is cancelled (returning nil) or until reading from, or
// responding on, remote fails (returning the first such error). remote is
// closed before ServeCodec returns.
func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) error {
	defer remote.Close()
	responder := &callResponder{
		remote: remote,
	}
	// add a cancel to the context so we can cancel all the child tasks on return
	ctx, cn := context.WithCancel(ContextWithPeerInfo(ctx, remote.PeerInfo()))
	defer cn()

	errch := make(chan error)
	// reportErr delivers err to the select below, or drops it once ServeCodec
	// has already returned (ctx is cancelled by the deferred cn). A bare send on
	// the unbuffered channel would block forever for every goroutine that fails
	// after the first error, leaking it.
	reportErr := func(err error) {
		select {
		case errch <- err:
		case <-ctx.Done():
		}
	}

	go func() {
		for {
			// read messages from the stream synchronously
			incoming, batch, err := remote.ReadBatch(ctx)
			if err != nil {
				reportErr(err)
				return
			}
			go func() {
				if err := s.serveBatch(ctx, incoming, batch, remote, responder); err != nil {
					reportErr(err)
				}
			}()
		}
	}()
	// exit on either the first error, or the context closing.
	select {
	case <-ctx.Done():
		return nil
	case err := <-errch:
		return err
	}
}

a's avatar
a committed
func (s *Server) serveBatch(ctx context.Context,
	incoming []*codec.Message,
	batch bool,
	remote codec.ReaderWriter, responder *callResponder) error {
a's avatar
a committed
	env := &callEnv{
		batch: batch,
	}
a's avatar
a committed

a's avatar
a committed
	// check for empty batch
a's avatar
a committed
	if batch && len(incoming) == 0 {
a's avatar
a  
a committed
		// if it is empty batch, send the empty batch error and immediately return
		return responder.send(ctx, &callEnv{
a's avatar
a committed
			responses: []*callRespWriter{{
a's avatar
a committed
				pkt: &codec.Message{
					ID:    codec.NewNullIDPtr(),
					Error: codec.NewInvalidRequestError("empty batch"),
				},
a's avatar
a committed
			}},
			batch: false,
a's avatar
a  
a committed
		})
a's avatar
a committed
	}

a's avatar
a  
a committed
	// populate the envelope we are about to send. this is synchronous pre-prpcessing
a's avatar
a committed
	for _, v := range incoming {
a's avatar
a  
a committed
		// create the response writer
a's avatar
a committed
		rw := &callRespWriter{
a's avatar
a  
a committed
			notifications: func(env *notifyEnv) error { return responder.notify(ctx, env) },
a's avatar
a committed
			header:        remote.PeerInfo().HTTP.Headers,
		}
a's avatar
a committed
		env.responses = append(env.responses, rw)
a's avatar
a  
a committed
		// a nil incoming message means an empty response
		if v == nil {
			rw.msg = &codec.Message{ID: codec.NewNullIDPtr()}
			rw.pkt = &codec.Message{ID: codec.NewNullIDPtr()}
			continue
		}
		rw.msg = v
		if v.ID == nil {
			rw.pkt = &codec.Message{ID: codec.NewNullIDPtr()}
			continue
		}
		rw.pkt = &codec.Message{ID: v.ID}
a's avatar
a committed
	}

	// create a waitgroup
	wg := sync.WaitGroup{}
a's avatar
ok  
a committed
	wg.Add(len(env.responses))
a's avatar
a  
a committed
	// for each item in the envelope
	peerInfo := remote.PeerInfo()
a's avatar
a committed
	for _, vRef := range env.responses {
		v := vRef
a's avatar
a  
a committed
		// process each request in its own goroutine
a's avatar
a committed
		go func() {
			defer wg.Done()
a's avatar
a  
a committed
			// early respond to nil requests
			if v.msg == nil || len(v.msg.Method) == 0 {
				v.pkt.Error = codec.NewInvalidRequestError("invalid request")
				return
			}
			if v.msg.ID == nil || v.msg.ID.IsNull() {
				// it's a notification, so we mark skip and we don't write anything for it
				v.skip = true
				return
			}
a's avatar
a committed
			r := codec.NewRequestFromMessage(
a's avatar
a committed
				ctx,
a's avatar
a committed
				v.msg,
			)
a's avatar
a  
a committed
			r.Peer = peerInfo
a's avatar
a committed
			s.services.ServeRPC(v, r)
a's avatar
a committed
		}()
	}
	wg.Wait()
a's avatar
a  
a committed
	return responder.send(ctx, env)
a's avatar
a committed
}

a's avatar
a committed
type callResponder struct {
a's avatar
a  
a committed
	remote codec.ReaderWriter
a's avatar
a committed
	mu     sync.Mutex
a's avatar
a committed
}
a's avatar
a committed

type notifyEnv struct {
	method string
	dat    any
	extra  []codec.RequestField
}

a's avatar
a committed
func (c *callResponder) notify(ctx context.Context, env *notifyEnv) error {
a's avatar
jx  
a committed
	err := c.remote.Send(func(e *jx.Encoder) error {
		msg := &codec.Message{}
		var err error
		//  allocate a temp buffer for this packet
		buf := bufpool.GetStd()
		defer bufpool.PutStd(buf)
		err = json.NewEncoder(buf).Encode(env.dat)
		if err != nil {
			msg.Error = err
		} else {
			msg.Params = buf.Bytes()
		}
		msg.ExtraFields = env.extra
		// add the method
		msg.Method = env.method
		err = codec.MarshalMessage(msg, e)
		if err != nil {
			return err
		}
		return nil
	})
a's avatar
a committed
	if err != nil {
		return err
	}
	return nil
a's avatar
a committed
}

type callEnv struct {
	responses []*callRespWriter
	batch     bool
a's avatar
a committed
}

a's avatar
a committed
func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) {
a's avatar
ok  
a committed
	// notification gets nothing
a's avatar
a committed
	// if all msgs in batch are notification, we trigger an allSkip and write nothing
a's avatar
ok  
a committed
	if env.batch {
		allSkip := true
		for _, v := range env.responses {
			if v.skip != true {
				allSkip = false
			}
		}
		if allSkip {
a's avatar
jx  
a committed
			return c.remote.Send(func(e *jx.Encoder) error { return nil })
a's avatar
ok  
a committed
		}
	}
a's avatar
a committed
	// create the streaming encoder
a's avatar
jx  
a committed
	err = c.remote.Send(func(enc *jx.Encoder) error {
		if env.batch {
			enc.ArrStart()
a's avatar
a committed
		}
a's avatar
jx  
a committed
		for _, v := range env.responses {
			msg := v.pkt
			// if we are a batch AND we are supposed to skip, then continue
			// this means that for a non-batch notification, we do not skip! this is to ensure we get always a "response" for http-like endpoints
			if env.batch && v.skip {
				continue
			}
			// if there is no error, we try to marshal the result
			if msg.Error == nil {
				buf := bufpool.GetStd()
				defer bufpool.PutStd(buf)
				je := json.NewEncoder(buf)
				err = je.EncodeWithOption(v.dat)
				if err != nil {
					msg.Error = err
				} else {
					msg.Result = buf.Bytes()
					msg.Result = bytes.TrimSuffix(msg.Result, []byte{'\n'})
				}
			}
			// then marshal the whole message into the stream
			err := codec.MarshalMessage(msg, enc)
a's avatar
a committed
			if err != nil {
a's avatar
jx  
a committed
				return err
a's avatar
a committed
			}
a's avatar
a committed
		}
a's avatar
jx  
a committed
		if env.batch {
			enc.ArrEnd()
a's avatar
a committed
		}
a's avatar
jx  
a committed
		return nil
	})
a's avatar
a committed
	if err != nil {
		return err
	}
	return nil
}