diff --git a/pkg/server/limitio.go b/pkg/server/limitio.go
new file mode 100644
index 0000000000000000000000000000000000000000..746e1c78d76605bed153d1029e433cdfad909b2e
--- /dev/null
+++ b/pkg/server/limitio.go
@@ -0,0 +1,72 @@
+package server
+
+import (
+	"errors"
+	"fmt"
+	"io"
+)
+
+var _ io.Writer = (*Writer)(nil)
+var ErrThresholdExceeded = errors.New("stream size exceeds threshold")
+
+// Writer wraps w with a write length limit.
+//
+// To create a Writer, use newWriter().
+type Writer struct {
+	w                    io.Writer
+	written              int
+	limit                int
+	regardOverSizeNormal bool
+}
+
+// newWriter creates a Writer that writes at most n bytes to w.
+//
+// regardOverSizeNormal controls whether Writer.Write() returns an error
+// once more than n bytes have been written in total, or skips the write
+// to the inner w while pretending it succeeded.
+func newWriter(w io.Writer, n int, regardOverSizeNormal bool) *Writer {
+	return &Writer{
+		w:                    w,
+		written:              0,
+		limit:                n,
+		regardOverSizeNormal: regardOverSizeNormal,
+	}
+}
+
+// Write implements io.Writer.
+func (lw *Writer) Write(p []byte) (n int, err error) {
+	if lw.written >= lw.limit {
+		if lw.regardOverSizeNormal {
+			n = len(p)
+			lw.written += n
+			return
+		}
+
+		err = fmt.Errorf("threshold is %d bytes: %w", lw.limit, ErrThresholdExceeded)
+		return
+	}
+
+	var (
+		overSized   bool
+		originalLen int
+	)
+
+	left := lw.limit - lw.written
+	if originalLen = len(p); originalLen > left {
+		overSized = true
+		p = p[0:left]
+	}
+	n, err = lw.w.Write(p)
+	lw.written += n
+	if overSized && err == nil {
+		// Write must return a non-nil error if it returns n < len(p).
+		if lw.regardOverSizeNormal {
+			return originalLen, nil
+		}
+
+		err = fmt.Errorf("threshold is %d bytes: %w", lw.limit, ErrThresholdExceeded)
+		return
+	}
+
+	return
+}
diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go
index 404e5858664c8901e8f9da98275ab6cc4900e737..a25fbbdb09a8ac745d708c3da3fd6524b3be5a10 100644
--- a/pkg/server/responsewriter.go
+++ b/pkg/server/responsewriter.go
@@ -1,15 +1,22 @@
 package server
 
 import (
+	"bytes"
 	"context"
 	"net/http"
 	"sync"
 
 	"gfx.cafe/open/jrpc/pkg/codec"
+	"gfx.cafe/util/go/bufpool"
 	"github.com/goccy/go-json"
 	"golang.org/x/sync/semaphore"
 )
 
+// 16 GiB... should be more than enough for any batch.
+// you shouldn't be batching more than this.
+// TODO: make this configurable.
+const maxBatchSizeBytes = 1024 * 1024 * 1024 * 16
+
 var _ codec.ResponseWriter = (*callRespWriter)(nil)
 
 // callRespWriter is NOT thread safe
@@ -54,10 +61,14 @@ func (c *callRespWriter) Send(v any, e error) (err error) {
 	}
 	if v != nil {
 		// json marshaling errors are reported to the handler
-		c.payload, err = json.Marshal(v)
+		buf := bufpool.GlobalPool.GetStd()
+		w := newWriter(buf, maxBatchSizeBytes, false)
+		err = json.NewEncoder(w).Encode(v)
 		if err != nil {
 			return err
 		}
+		c.payload = json.RawMessage(bytes.TrimSuffix(buf.Bytes(), []byte{'\n'}))
+		return nil
 	}
 	return nil
 }
diff --git a/pkg/server/server.go b/pkg/server/server.go
index f517002a1d84efc3905ba75c1ff2d7e4e8e6d14e..e49209d5b5df5adf13e3032e6c204b316c022835 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -174,6 +174,7 @@ func (s *Server) serveBatch(ctx context.Context,
 		}()
 	}
 	if r.batch && totalRequests > 0 {
+		err = doneMu.Acquire(ctx, int64(totalRequests))
 		if err != nil {
 			return err
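The new limit writer has two behaviors on overflow: with `regardOverSizeNormal == false` it truncates at the threshold and returns an error wrapping `ErrThresholdExceeded`; with `true` it reports the full length as written while discarding the excess. Below is a minimal sketch (not part of the diff) of how that could be exercised; it assumes it lives inside package `server`, since `newWriter` is unexported, and the test name is hypothetical.

```go
package server

import (
	"bytes"
	"errors"
	"testing"
)

func TestLimitWriterOverflowModes(t *testing.T) {
	payload := []byte("hello world") // 11 bytes, threshold is 5

	// Strict mode: only 5 bytes reach the buffer and the sentinel error
	// is reported via the returned (wrapped) error.
	var strict bytes.Buffer
	n, err := newWriter(&strict, 5, false).Write(payload)
	if n != 5 || !errors.Is(err, ErrThresholdExceeded) {
		t.Fatalf("strict mode: n=%d err=%v", n, err)
	}
	if strict.String() != "hello" {
		t.Fatalf("strict mode wrote %q", strict.String())
	}

	// Lenient mode: the caller is told the whole slice was written,
	// but the underlying buffer still only receives the first 5 bytes.
	var lenient bytes.Buffer
	n, err = newWriter(&lenient, 5, true).Write(payload)
	if n != len(payload) || err != nil {
		t.Fatalf("lenient mode: n=%d err=%v", n, err)
	}
	if lenient.String() != "hello" {
		t.Fatalf("lenient mode wrote %q", lenient.String())
	}
}
```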
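The `bytes.TrimSuffix` in `Send` exists because streaming through `json.NewEncoder(...).Encode(v)` appends a trailing newline that `json.Marshal(v)` never produced, and the payload should stay byte-for-byte compatible with the old code. A standalone sketch of that difference follows, using the standard `encoding/json` on the assumption that the goccy drop-in used in the diff mirrors its behavior:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

func main() {
	v := map[string]int{"id": 1}

	// Marshal returns the JSON document with no trailing newline.
	marshaled, _ := json.Marshal(v)

	// Encoder.Encode writes the same document plus a '\n' terminator.
	var buf bytes.Buffer
	_ = json.NewEncoder(&buf).Encode(v)
	encoded := buf.Bytes()

	fmt.Printf("Marshal: %q\n", marshaled)                               // "{\"id\":1}"
	fmt.Printf("Encode:  %q\n", encoded)                                 // "{\"id\":1}\n"
	fmt.Printf("Trimmed: %q\n", bytes.TrimSuffix(encoded, []byte{'\n'})) // "{\"id\":1}"
}
```

Encoding through the limit writer into the pooled buffer also means an oversized response fails with `ErrThresholdExceeded` instead of growing the buffer without bound.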