From f18b95a0f4b400e327e6d6eff7866d7a15b09c99 Mon Sep 17 00:00:00 2001 From: a <a@tuxpa.in> Date: Tue, 26 Dec 2023 01:10:13 -0600 Subject: [PATCH] simplify the server --- contrib/codecs/http/codecs.go | 6 ++- contrib/codecs/http/http_test.go | 47 ++------------------ pkg/server/server.go | 74 +++++++++++++++++++------------- 3 files changed, 52 insertions(+), 75 deletions(-) diff --git a/contrib/codecs/http/codecs.go b/contrib/codecs/http/codecs.go index c14e4a9..c5e42b5 100644 --- a/contrib/codecs/http/codecs.go +++ b/contrib/codecs/http/codecs.go @@ -116,9 +116,11 @@ func (c *HttpCodec) PeerInfo() jsonrpc.PeerInfo { func (c *HttpCodec) ReadBatch(ctx context.Context) ([]*jsonrpc.Message, bool, error) { if c.msgs == nil { - return nil, false, io.EOF + return nil, false, context.Canceled } - c.msgs = nil + defer func() { + c.msgs = nil + }() return c.msgs.Messages, c.msgs.Batch, nil } diff --git a/contrib/codecs/http/http_test.go b/contrib/codecs/http/http_test.go index 1e8c456..bcf9a29 100644 --- a/contrib/codecs/http/http_test.go +++ b/contrib/codecs/http/http_test.go @@ -51,51 +51,10 @@ func confirmStatusCode(t *testing.T, got, want int) { t.Fatalf("response status code: got %d, want %d", got, want) } -func confirmRequestValidationCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { - t.Helper() - request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body)) - if len(contentType) > 0 { - request.Header.Set("Content-Type", contentType) - } - code, err := ValidateRequest(request) - if code == 0 { - if err != nil { - t.Errorf("validation: got error %v, expected nil", err) - } - } else if err == nil { - t.Errorf("validation: code %d: got nil, expected error", code) - } - confirmStatusCode(t, code, expectedStatusCode) -} - -func TestHTTPErrorResponseWithDelete(t *testing.T) { - confirmRequestValidationCode(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed) -} - -func TestHTTPErrorResponseWithPut(t *testing.T) { - confirmRequestValidationCode(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed) -} - -func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) { - body := make([]rune, maxRequestContentLength+1) - confirmRequestValidationCode(t, - http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge) -} - -//NOTE: this test is not needed since we no longer check this -// -//func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) { -// confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType) -//} - -func TestHTTPErrorResponseWithValidRequest(t *testing.T) { - confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0) -} - func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { t.Helper() s := server.NewServer(jmux.NewMux()) - ts := httptest.NewServer(&Server{Server: s}) + ts := httptest.NewServer(HttpHandler(s)) defer ts.Close() request, err := http.NewRequest(method, ts.URL, strings.NewReader(body)) @@ -119,7 +78,7 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) { // This checks that maxRequestContentLength is not applied to the response of a request. func TestHTTPRespBodyUnlimited(t *testing.T) { s := jrpctest.NewServer() - ts := httptest.NewServer(&Server{Server: s}) + ts := httptest.NewServer(HttpHandler(s)) defer ts.Close() c, err := DialHTTP(ts.URL) @@ -178,7 +137,7 @@ func TestHTTPErrorResponse(t *testing.T) { func TestClientHTTP(t *testing.T) { s := jrpctest.NewServer() - ts := httptest.NewServer(&Server{Server: s}) + ts := httptest.NewServer(HttpHandler(s)) defer ts.Close() c, err := DialHTTP(ts.URL) if err != nil { diff --git a/pkg/server/server.go b/pkg/server/server.go index dc9aff3..a6554c0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -7,9 +7,11 @@ import ( "sync" "github.com/mailgun/multibuf" + "golang.org/x/sync/errgroup" "gfx.cafe/open/jrpc/pkg/jjson" "gfx.cafe/open/jrpc/pkg/jsonrpc" + "gfx.cafe/open/jrpc/pkg/serverutil" ) // Server is an RPC server. @@ -30,52 +32,66 @@ func NewServer(r jsonrpc.Handler) *Server { } // ServeCodec reads incoming requests from codec, calls the appropriate callback and writes -// the response back using the given codec. It will block until the codec is closed +// the response back using the given codec. It will block until the codec is closed. +// the codec will return if either of these conditions are met +// 1. every request read from ReadBatch until ReadBatch returns context.Canceled is processed. +// 2. there is a server related error (failed encoding, broken conn) that was received while processing/reading messages. func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error { defer remote.Close() - stream := jsonrpc.NewStream(remote) // add a cancel to the context so we can cancel all the child tasks on return ctx = ContextWithPeerInfo(ctx, remote.PeerInfo()) ctx = ContextWithMessageStream(ctx, stream) ctx, cn := context.WithCancel(ctx) defer cn() - - var allErrs []error - var mu sync.Mutex - wg := sync.WaitGroup{} - err := func() error { + errCh := make(chan error) + batches := make(chan serverutil.Bundle, 1) + go func() { + defer close(batches) for { // read messages from the stream synchronously incoming, batch, err := remote.ReadBatch(ctx) if err != nil { - return err - } - wg.Add(1) - go func() { - defer wg.Done() - responder := &callResponder{ - peerinfo: remote.PeerInfo(), - batch: batch, - stream: stream, - } - err = s.serve(ctx, incoming, responder) - if err != nil { - mu.Lock() - defer mu.Unlock() - allErrs = append(allErrs, err) + // if its not context canceled, aka our graceful closure, we error, otherwise we only return + // in both cases we close the batches channel. this error will then immediately return. + if !errors.Is(err, context.Canceled) { + errCh <- err } - }() + return + } + batches <- serverutil.Bundle{ + Messages: incoming, + Batch: batch, + } } }() - wg.Wait() - if err != nil { - allErrs = append(allErrs, err) + wg := sync.WaitGroup{} + // this errgroup controls the max concurrent requests per codec + egg := errgroup.Group{} + for batch := range batches { + incoming, batch := batch.Messages, batch.Batch + wg.Add(1) + responder := &callResponder{ + peerinfo: remote.PeerInfo(), + batch: batch, + stream: stream, + } + egg.Go(func() error { + return s.serve(ctx, incoming, responder) + }) } - if len(allErrs) > 0 { - return errors.Join(allErrs...) + go func() { + err := egg.Wait() + if err != nil { + errCh <- err + return + } + errCh <- nil + }() + select { + case err := <-errCh: + return err } - return nil } func (s *Server) Shutdown(ctx context.Context) { -- GitLab