From 06d19b75f3ce0cd4c3790bb31c544945c5c13c86 Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Fri, 2 Feb 2024 03:16:21 -0600
Subject: [PATCH] ok
---
contrib/codecs/websocket/codec.go | 23 +++++++++++++++++++----
contrib/codecs/websocket/handler.go | 2 +-
pkg/server/responsewriter.go | 4 ++++
pkg/server/server.go | 27 +++++++++++++++++++--------
4 files changed, 43 insertions(+), 13 deletions(-)
diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go
index b5337c7..29e4744 100644
--- a/contrib/codecs/websocket/codec.go
+++ b/contrib/codecs/websocket/codec.go
@@ -20,6 +20,8 @@ import (
type Codec struct {
closed chan struct{}
conn *websocket.Conn
+ closer func()
+ ctx context.Context
currentFrame io.WriteCloser
wrLock sync.Mutex
@@ -32,14 +34,20 @@ type Codec struct {
func newWebsocketCodec(ctx context.Context, conn *websocket.Conn, host string, req *http.Request) *Codec {
conn.SetReadLimit(WsMessageSizeLimit)
+
+ ctx, cn := context.WithCancel(ctx)
c := &Codec{
closed: make(chan struct{}),
conn: conn,
decLock: semaphore.NewWeighted(1),
+ ctx: ctx,
+ }
+ c.closer = func() {
+ cn()
}
c.i.Transport = "ws"
// Fill in connection details.
- c.i.HTTP = req.Clone(req.Context())
+ c.i.HTTP = req.Clone(ctx)
// Start pinger.
go heartbeat(ctx, conn, WsPingInterval)
return c
@@ -92,20 +100,26 @@ func (c *Codec) Write(p []byte) (n int, err error) {
c.wrLock.Lock()
defer c.wrLock.Unlock()
if c.currentFrame == nil {
- wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
+ wr, err := c.conn.Writer(c.ctx, websocket.MessageText)
if err != nil {
+ c.Close()
return 0, err
}
c.currentFrame = wr
}
- return c.currentFrame.Write(p)
+
+ n, err = c.currentFrame.Write(p)
+ if err != nil {
+ c.Close()
+ }
+ return
}
func (c *Codec) Flush() error {
c.wrLock.Lock()
defer c.wrLock.Unlock()
if c.currentFrame == nil {
- wr, err := c.conn.Writer(context.Background(), websocket.MessageText)
+ wr, err := c.conn.Writer(c.ctx, websocket.MessageText)
if err != nil {
return err
}
@@ -132,6 +146,7 @@ func (c *Codec) Close() error {
case <-c.closed:
return nil
default:
+ c.closer()
close(c.closed)
}
return c.conn.Close(websocket.StatusNormalClosure, "")
diff --git a/contrib/codecs/websocket/handler.go b/contrib/codecs/websocket/handler.go
index 2458a85..b3ef650 100644
--- a/contrib/codecs/websocket/handler.go
+++ b/contrib/codecs/websocket/handler.go
@@ -25,7 +25,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := newWebsocketCodec(r.Context(), conn, "", r)
err = s.Server.ServeCodec(r.Context(), c)
if err != nil {
- // slog.Error("codec err", "error", err)
+ //slog.Error("codec err", "error", err)
}
}
diff --git a/pkg/server/responsewriter.go b/pkg/server/responsewriter.go
index fb7676c..0b665d0 100644
--- a/pkg/server/responsewriter.go
+++ b/pkg/server/responsewriter.go
@@ -12,6 +12,8 @@ var _ jsonrpc.ResponseWriter = (*streamingRespWriter)(nil)
type streamingRespWriter struct {
// this should be the same context as the request
ctx context.Context
+ // if there is an unrecoverable error, this should be used to immediately terminate the codec
+ cancel func()
// the stream that Send will write to
sendStream jsonrpc.MessageStreamer
// the stream that Notify will write to
@@ -73,6 +75,7 @@ func (c *streamingRespWriter) Send(v any, e error) (err error) {
}
defer msg.Close()
if err = send(ce, msg); err != nil {
+ c.cancel()
return err
}
return nil
@@ -89,6 +92,7 @@ func (c *streamingRespWriter) Notify(method string, v any) error {
dat: v,
}, msg)
if err != nil {
+ c.cancel()
return err
}
return nil
diff --git a/pkg/server/server.go b/pkg/server/server.go
index 4954fdf..393b0c2 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -51,10 +51,12 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
ctx = ContextWithMessageStream(ctx, stream)
ctx, cn := context.WithCancel(ctx)
defer cn()
- errCh := make(chan error)
+ errCh := make(chan error, 1)
batches := make(chan serverutil.Bundle, 1)
go func() {
- defer close(batches)
+ defer func() {
+ close(batches)
+ }()
for {
// read messages from the stream synchronously
incoming, batch, err := remote.ReadBatch(ctx)
@@ -62,7 +64,10 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
// if its not context canceled, aka our graceful closure, we error, otherwise we only return
// in both cases we close the batches channel. this error will then immediately return.
if !errors.Is(err, context.Canceled) {
- errCh <- err
+ select {
+ case errCh <- err:
+ default:
+ }
}
return
}
@@ -75,6 +80,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
wg := sync.WaitGroup{}
// this errgroup controls the max concurrent requests per codec
egg := errgroup.Group{}
+ egg.SetLimit(4)
for batch := range batches {
incoming, batch := batch.Messages, batch.Batch
wg.Add(1)
@@ -84,7 +90,9 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
stream: stream,
}
egg.Go(func() error {
- return s.serve(ctx, incoming, responder)
+ return s.serve(ctx, func() {
+ cn()
+ }, incoming, responder)
})
}
go func() {
@@ -95,6 +103,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er
}
errCh <- nil
}()
+
select {
case err := <-errCh:
return err
@@ -106,23 +115,24 @@ func (s *Server) Shutdown(ctx context.Context) error {
return nil
}
-func (s *Server) serve(ctx context.Context,
+func (s *Server) serve(ctx context.Context, cancelFunc func(),
incoming []*jsonrpc.Message,
r *callResponder,
) error {
if r.batch {
- return s.serveBatch(ctx, incoming, r)
+ return s.serveBatch(ctx, cancelFunc, incoming, r)
} else {
- return s.serveSingle(ctx, incoming[0], r)
+ return s.serveSingle(ctx, cancelFunc, incoming[0], r)
}
}
-func (s *Server) serveSingle(ctx context.Context,
+func (s *Server) serveSingle(ctx context.Context, cancelFunc func(),
incoming *jsonrpc.Message,
r *callResponder,
) error {
rw := &streamingRespWriter{
ctx: ctx,
+ cancel: cancelFunc,
sendStream: r.stream,
notifyStream: r.stream,
}
@@ -172,6 +182,7 @@ func produceOutputMessage(inputMessage *jsonrpc.Message) (out *jsonrpc.Message,
}
func (s *Server) serveBatch(ctx context.Context,
+ cancelFunc func(),
incoming []*jsonrpc.Message,
r *callResponder,
) error {
--
GitLab