good morning!!!!

Skip to content
Snippets Groups Projects
Verified Commit f18b95a0 authored by a's avatar a
Browse files

simplify the server

parent ad0135e2
No related branches found
No related tags found
No related merge requests found
Pipeline #51109 passed with stage
in 2 minutes and 50 seconds
...@@ -116,9 +116,11 @@ func (c *HttpCodec) PeerInfo() jsonrpc.PeerInfo { ...@@ -116,9 +116,11 @@ func (c *HttpCodec) PeerInfo() jsonrpc.PeerInfo {
func (c *HttpCodec) ReadBatch(ctx context.Context) ([]*jsonrpc.Message, bool, error) { func (c *HttpCodec) ReadBatch(ctx context.Context) ([]*jsonrpc.Message, bool, error) {
if c.msgs == nil { if c.msgs == nil {
return nil, false, io.EOF return nil, false, context.Canceled
} }
c.msgs = nil defer func() {
c.msgs = nil
}()
return c.msgs.Messages, c.msgs.Batch, nil return c.msgs.Messages, c.msgs.Batch, nil
} }
......
...@@ -51,51 +51,10 @@ func confirmStatusCode(t *testing.T, got, want int) { ...@@ -51,51 +51,10 @@ func confirmStatusCode(t *testing.T, got, want int) {
t.Fatalf("response status code: got %d, want %d", got, want) t.Fatalf("response status code: got %d, want %d", got, want)
} }
func confirmRequestValidationCode(t *testing.T, method, contentType, body string, expectedStatusCode int) {
t.Helper()
request := httptest.NewRequest(method, "http://url.com", strings.NewReader(body))
if len(contentType) > 0 {
request.Header.Set("Content-Type", contentType)
}
code, err := ValidateRequest(request)
if code == 0 {
if err != nil {
t.Errorf("validation: got error %v, expected nil", err)
}
} else if err == nil {
t.Errorf("validation: code %d: got nil, expected error", code)
}
confirmStatusCode(t, code, expectedStatusCode)
}
func TestHTTPErrorResponseWithDelete(t *testing.T) {
confirmRequestValidationCode(t, http.MethodDelete, contentType, "", http.StatusMethodNotAllowed)
}
func TestHTTPErrorResponseWithPut(t *testing.T) {
confirmRequestValidationCode(t, http.MethodPut, contentType, "", http.StatusMethodNotAllowed)
}
func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
body := make([]rune, maxRequestContentLength+1)
confirmRequestValidationCode(t,
http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
}
//NOTE: this test is not needed since we no longer check this
//
//func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) {
// confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType)
//}
func TestHTTPErrorResponseWithValidRequest(t *testing.T) {
confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0)
}
func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) { func confirmHTTPRequestYieldsStatusCode(t *testing.T, method, contentType, body string, expectedStatusCode int) {
t.Helper() t.Helper()
s := server.NewServer(jmux.NewMux()) s := server.NewServer(jmux.NewMux())
ts := httptest.NewServer(&Server{Server: s}) ts := httptest.NewServer(HttpHandler(s))
defer ts.Close() defer ts.Close()
request, err := http.NewRequest(method, ts.URL, strings.NewReader(body)) request, err := http.NewRequest(method, ts.URL, strings.NewReader(body))
...@@ -119,7 +78,7 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) { ...@@ -119,7 +78,7 @@ func TestHTTPResponseWithEmptyGet(t *testing.T) {
// This checks that maxRequestContentLength is not applied to the response of a request. // This checks that maxRequestContentLength is not applied to the response of a request.
func TestHTTPRespBodyUnlimited(t *testing.T) { func TestHTTPRespBodyUnlimited(t *testing.T) {
s := jrpctest.NewServer() s := jrpctest.NewServer()
ts := httptest.NewServer(&Server{Server: s}) ts := httptest.NewServer(HttpHandler(s))
defer ts.Close() defer ts.Close()
c, err := DialHTTP(ts.URL) c, err := DialHTTP(ts.URL)
...@@ -178,7 +137,7 @@ func TestHTTPErrorResponse(t *testing.T) { ...@@ -178,7 +137,7 @@ func TestHTTPErrorResponse(t *testing.T) {
func TestClientHTTP(t *testing.T) { func TestClientHTTP(t *testing.T) {
s := jrpctest.NewServer() s := jrpctest.NewServer()
ts := httptest.NewServer(&Server{Server: s}) ts := httptest.NewServer(HttpHandler(s))
defer ts.Close() defer ts.Close()
c, err := DialHTTP(ts.URL) c, err := DialHTTP(ts.URL)
if err != nil { if err != nil {
......
...@@ -7,9 +7,11 @@ import ( ...@@ -7,9 +7,11 @@ import (
"sync" "sync"
"github.com/mailgun/multibuf" "github.com/mailgun/multibuf"
"golang.org/x/sync/errgroup"
"gfx.cafe/open/jrpc/pkg/jjson" "gfx.cafe/open/jrpc/pkg/jjson"
"gfx.cafe/open/jrpc/pkg/jsonrpc" "gfx.cafe/open/jrpc/pkg/jsonrpc"
"gfx.cafe/open/jrpc/pkg/serverutil"
) )
// Server is an RPC server. // Server is an RPC server.
...@@ -30,52 +32,66 @@ func NewServer(r jsonrpc.Handler) *Server { ...@@ -30,52 +32,66 @@ func NewServer(r jsonrpc.Handler) *Server {
} }
// ServeCodec reads incoming requests from codec, calls the appropriate callback and writes // ServeCodec reads incoming requests from codec, calls the appropriate callback and writes
// the response back using the given codec. It will block until the codec is closed // the response back using the given codec. It will block until the codec is closed.
// the codec will return if either of these conditions are met
// 1. every request read from ReadBatch until ReadBatch returns context.Canceled is processed.
// 2. there is a server related error (failed encoding, broken conn) that was received while processing/reading messages.
func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error { func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) error {
defer remote.Close() defer remote.Close()
stream := jsonrpc.NewStream(remote) stream := jsonrpc.NewStream(remote)
// add a cancel to the context so we can cancel all the child tasks on return // add a cancel to the context so we can cancel all the child tasks on return
ctx = ContextWithPeerInfo(ctx, remote.PeerInfo()) ctx = ContextWithPeerInfo(ctx, remote.PeerInfo())
ctx = ContextWithMessageStream(ctx, stream) ctx = ContextWithMessageStream(ctx, stream)
ctx, cn := context.WithCancel(ctx) ctx, cn := context.WithCancel(ctx)
defer cn() defer cn()
errCh := make(chan error)
var allErrs []error batches := make(chan serverutil.Bundle, 1)
var mu sync.Mutex go func() {
wg := sync.WaitGroup{} defer close(batches)
err := func() error {
for { for {
// read messages from the stream synchronously // read messages from the stream synchronously
incoming, batch, err := remote.ReadBatch(ctx) incoming, batch, err := remote.ReadBatch(ctx)
if err != nil { if err != nil {
return err // if its not context canceled, aka our graceful closure, we error, otherwise we only return
} // in both cases we close the batches channel. this error will then immediately return.
wg.Add(1) if !errors.Is(err, context.Canceled) {
go func() { errCh <- err
defer wg.Done()
responder := &callResponder{
peerinfo: remote.PeerInfo(),
batch: batch,
stream: stream,
}
err = s.serve(ctx, incoming, responder)
if err != nil {
mu.Lock()
defer mu.Unlock()
allErrs = append(allErrs, err)
} }
}() return
}
batches <- serverutil.Bundle{
Messages: incoming,
Batch: batch,
}
} }
}() }()
wg.Wait() wg := sync.WaitGroup{}
if err != nil { // this errgroup controls the max concurrent requests per codec
allErrs = append(allErrs, err) egg := errgroup.Group{}
for batch := range batches {
incoming, batch := batch.Messages, batch.Batch
wg.Add(1)
responder := &callResponder{
peerinfo: remote.PeerInfo(),
batch: batch,
stream: stream,
}
egg.Go(func() error {
return s.serve(ctx, incoming, responder)
})
} }
if len(allErrs) > 0 { go func() {
return errors.Join(allErrs...) err := egg.Wait()
if err != nil {
errCh <- err
return
}
errCh <- nil
}()
select {
case err := <-errCh:
return err
} }
return nil
} }
func (s *Server) Shutdown(ctx context.Context) { func (s *Server) Shutdown(ctx context.Context) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment