From 06d19b75f3ce0cd4c3790bb31c544945c5c13c86 Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Fri, 2 Feb 2024 03:16:21 -0600
Subject: [PATCH] ok

---
 contrib/codecs/websocket/codec.go   | 23 +++++++++++++++++++----
 contrib/codecs/websocket/handler.go |  2 +-
 pkg/server/responsewriter.go        |  4 ++++
 pkg/server/server.go                | 27 +++++++++++++++++++--------
 4 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go
index b5337c7..29e4744 100644
--- a/contrib/codecs/websocket/codec.go
+++ b/contrib/codecs/websocket/codec.go
@@ -20,6 +20,8 @@ import (
 type Codec struct {
 	closed chan struct{}
 	conn   *websocket.Conn
+	closer func()
+	ctx    context.Context
 
 	currentFrame io.WriteCloser
 	wrLock       sync.Mutex
@@ -32,14 +34,20 @@ type Codec struct {
 
 func newWebsocketCodec(ctx context.Context, conn *websocket.Conn, host string, req *http.Request) *Codec {
 	conn.SetReadLimit(WsMessageSizeLimit)
+
+	ctx, cn := context.WithCancel(ctx)
 	c := &Codec{
 		closed:  make(chan struct{}),
 		conn:    conn,
 		decLock: semaphore.NewWeighted(1),
+		ctx:     ctx,
+	}
+	c.closer = func() {
+		cn()
 	}
 	c.i.Transport = "ws"
 	// Fill in connection details.
-	c.i.HTTP = req.Clone(req.Context())
+	c.i.HTTP = req.Clone(ctx)
 	// Start pinger.
 	go heartbeat(ctx, conn, WsPingInterval)
 	return c
@@ -92,20 +100,26 @@ func (c *Codec) Write(p []byte) (n int, err error) {
 	c.wrLock.Lock()
 	defer c.wrLock.Unlock()
 	if c.currentFrame == nil {
-		wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
+		wr, err := c.conn.Writer(c.ctx, websocket.MessageText)
 		if err != nil {
+			c.Close()
 			return 0, err
 		}
 		c.currentFrame = wr
 	}
-	return c.currentFrame.Write(p)
+
+	n, err = c.currentFrame.Write(p)
+	if err != nil {
+		c.Close()
+	}
+	return
 }
 
 func (c *Codec) Flush() error {
 	c.wrLock.Lock()
 	defer c.wrLock.Unlock()
 	if c.currentFrame == nil {
-		wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
+		wr, err := c.conn.Writer(c.ctx, websocket.MessageText)
 		if err != nil {
 			return err
 		}
@@ -132,6 +146,7 @@ func (c *Codec) Close() error {
 	case <-c.closed:
 		return nil
 	default:
+		c.closer()
 		close(c.closed)
 	}
 	return c.conn.Close(websocket.StatusNormalClosure, "")
diff --git a/contrib/codecs/websocket/handler.go b/contrib/codecs/websocket/handler.go
index 2458a85..b3ef650 100644
--- a/contrib/codecs/websocket/handler.go
+++ b/contrib/codecs/websocket/handler.go
@@ -25,7 +25,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	c := newWebsocketCodec(r.Context(), conn, "", r)
 	err = s.Server.ServeCodec(r.Context(), c)
 	if err != nil {
-		// slog.Error("codec err", "error", err)
+		//slog.Error("codec err", "error", err)
 	}
 }
 
diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go
index fb7676c..0b665d0 100644
--- a/pkg/server/responsewriter.go
+++ b/pkg/server/responsewriter.go
@@ -12,6 +12,8 @@ var _ jsonrpc.ResponseWriter = (*streamingRespWriter)(nil)
 type streamingRespWriter struct {
 	// this should be the same context as the request
 	ctx context.Context
+	// if there is an unrecoverable error, this should be used to immediately terminate the codec
+	cancel func()
 	// the stream that Send will write to
 	sendStream jsonrpc.MessageStreamer
 	// the stream that Notify will write to
@@ -73,6 +75,7 @@ func (c *streamingRespWriter) Send(v any, e error) (err error) {
 	}
 	defer msg.Close()
 	if err = send(ce, msg); err != nil {
+		c.cancel()
 		return err
 	}
 	return nil
@@ -89,6 +92,7 @@ func (c *streamingRespWriter) Notify(method string, v any) error {
 		dat:    v,
 	}, msg)
 	if err != nil {
+		c.cancel()
 		return err
 	}
 	return nil
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 4954fdf..393b0c2 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -51,10 +51,12 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
 	ctx = ContextWithMessageStream(ctx, stream)
 	ctx, cn := context.WithCancel(ctx)
 	defer cn()
-	errCh := make(chan error)
+	errCh := make(chan error, 1)
 	batches := make(chan serverutil.Bundle, 1)
 	go func() {
-		defer close(batches)
+		defer func() {
+			close(batches)
+		}()
 		for {
 			// read messages from the stream synchronously
 			incoming, batch, err := remote.ReadBatch(ctx)
@@ -62,7 +64,10 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
 				// if its not context canceled, aka our graceful closure, we error, otherwise we only return
 				// in both cases we close the batches channel. this error will then immediately return.
 				if !errors.Is(err, context.Canceled) {
-					errCh <- err
+					select {
+					case errCh <- err:
+					default:
+					}
 				}
 				return
 			}
@@ -75,6 +80,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
 	wg := sync.WaitGroup{}
 	// this errgroup controls the max concurrent requests per codec
 	egg := errgroup.Group{}
+	egg.SetLimit(4)
 	for batch := range batches {
 		incoming, batch := batch.Messages, batch.Batch
 		wg.Add(1)
@@ -84,7 +90,9 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
 			stream:   stream,
 		}
 		egg.Go(func() error {
-			return s.serve(ctx, incoming, responder)
+			return s.serve(ctx, func() {
+				cn()
+			}, incoming, responder)
 		})
 	}
 	go func() {
@@ -95,6 +103,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
 		}
 		errCh <- nil
 	}()
+
 	select {
 	case err := <-errCh:
 		return err
@@ -106,23 +115,24 @@ func (s *Server) Shutdown(ctx context.Context) error {
 	return nil
 }
 
-func (s *Server) serve(ctx context.Context,
+func (s *Server) serve(ctx context.Context, cancelFunc func(),
 	incoming []*jsonrpc.Message,
 	r *callResponder,
 ) error {
 	if r.batch {
-		return s.serveBatch(ctx, incoming, r)
+		return s.serveBatch(ctx, cancelFunc, incoming, r)
 	} else {
-		return s.serveSingle(ctx, incoming[0], r)
+		return s.serveSingle(ctx, cancelFunc, incoming[0], r)
 	}
 }
 
-func (s *Server) serveSingle(ctx context.Context,
+func (s *Server) serveSingle(ctx context.Context, cancelFunc func(),
 	incoming *jsonrpc.Message,
 	r *callResponder,
 ) error {
 	rw := &streamingRespWriter{
 		ctx:          ctx,
+		cancel:       cancelFunc,
 		sendStream:   r.stream,
 		notifyStream: r.stream,
 	}
@@ -172,6 +182,7 @@ func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message,
 }
 
 func (s *Server) serveBatch(ctx context.Context,
+	cancelFunc func(),
 	incoming []*jsonrpc.Message,
 	r *callResponder,
 ) error {
-- 
GitLab