diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go index 68c0517632d821a7c3920fd97be547f16c8b71d7..b5337c7876b70bd5886475b496c4218954bde30f 100644 --- a/contrib/codecs/websocket/codec.go +++ b/contrib/codecs/websocket/codec.go @@ -10,6 +10,7 @@ import ( "time" "gfx.cafe/open/websocket" + "golang.org/x/sync/semaphore" "gfx.cafe/open/jrpc/pkg/jjson" "gfx.cafe/open/jrpc/pkg/jsonrpc" @@ -24,7 +25,7 @@ type Codec struct { wrLock sync.Mutex decBuf json.RawMessage - decLock sync.Mutex + decLock *semaphore.Weighted i jsonrpc.PeerInfo } @@ -32,8 +33,9 @@ type Codec struct { func newWebsocketCodec(ctx context.Context, conn *websocket.Conn, host string, req *http.Request) *Codec { conn.SetReadLimit(WsMessageSizeLimit) c := &Codec{ - closed: make(chan struct{}), - conn: conn, + closed: make(chan struct{}), + conn: conn, + decLock: semaphore.NewWeighted(1), } c.i.Transport = "ws" // Fill in connection details. @@ -61,8 +63,10 @@ func heartbeat(ctx context.Context, c *websocket.Conn, d time.Duration) { } func (c *Codec) decodeSingleMessage(ctx context.Context) (*serverutil.Bundle, error) { - c.decLock.Lock() - defer c.decLock.Unlock() + if err := c.decLock.Acquire(ctx, 1); err != nil { + return nil, err + } + defer c.decLock.Release(1) c.decBuf = c.decBuf[:0] _, r, err := c.conn.Reader(ctx) if err != nil { diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go index 2a8fc3146c3ec3547f5a2b916ab7035d77c84245..4d6c688984ff4002834216bf8a9a34e4e3a38722 100644 --- a/contrib/extension/subscription/client_test.go +++ b/contrib/extension/subscription/client_test.go @@ -36,8 +36,10 @@ func TestSubscription(t *testing.T) { }) srv := server.NewServer(r) + defer srv.Shutdown(context.Background()) handler := codecs.WebsocketHandler(srv, []string{"*"}) httpSrv := httptest.NewServer(handler) + defer httpSrv.Close() wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") cl, err := UpgradeConn(jrpc.Dial(wsURL)) @@ -45,6 +47,11 @@ func TestSubscription(t *testing.T) { t.Error(err) return } + defer func() { + if err = cl.Close(); err != nil { + t.Error(err) + } + }() ch := make(chan int, count) sub, err := cl.Subscribe(context.Background(), "test", ch, nil) @@ -81,8 +88,10 @@ func TestUnsubscribeNoRead(t *testing.T) { }) srv := server.NewServer(r) + defer srv.Shutdown(context.Background()) handler := codecs.WebsocketHandler(srv, []string{"*"}) httpSrv := httptest.NewServer(handler) + defer httpSrv.Close() wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") cl, err := UpgradeConn(jrpc.Dial(wsURL)) @@ -90,6 +99,11 @@ func TestUnsubscribeNoRead(t *testing.T) { t.Error(err) return } + defer func() { + if err = cl.Close(); err != nil { + t.Error(err) + } + }() ch := make(chan int) sub, err := cl.Subscribe(context.Background(), "test", ch, nil) @@ -140,8 +154,10 @@ func TestWrapClient(t *testing.T) { }() }) srv := server.NewServer(r) + defer srv.Shutdown(context.Background()) handler := codecs.WebsocketHandler(srv, []string{"*"}) httpSrv := httptest.NewServer(handler) + defer httpSrv.Close() wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") cl, err := UpgradeConn(jrpc.Dial(wsURL)) @@ -149,6 +165,11 @@ func TestWrapClient(t *testing.T) { t.Error(err) return } + defer func() { + if err = cl.Close(); err != nil { + t.Error(err) + } + }() for i := 0; i < 10; i++ { var res string diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go index 210c7d97da62523d839c23edf1ad2076e4997e8a..b175b6fad3cd050371762f4bb25a6ee31506463e 100644 --- a/pkg/jrpctest/suites.go +++ b/pkg/jrpctest/suites.go @@ -30,6 +30,7 @@ func TestExecutor(sm ServerMaker) func(t *testing.T, c TestContext) { return func(t *testing.T, c TestContext) { server, dialer, cn := sm() defer cn() + defer server.Shutdown(context.Background()) client := dialer() defer client.Close() c(t, server, client) @@ -39,6 +40,7 @@ func BenchExecutor(sm ServerMaker) func(t *testing.B, c BenchContext) { return func(t *testing.B, c BenchContext) { server, dialer, cn := sm() defer cn() + defer server.Shutdown(context.Background()) client := dialer() defer client.Close() c(t, server, client) diff --git a/pkg/server/server.go b/pkg/server/server.go index eb1aad99f44b77e5ec6798fdc75b0b52e5eff375..dc9aff3e293d3af278b53896149ee7e92232c404 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -41,7 +41,7 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er ctx, cn := context.WithCancel(ctx) defer cn() - allErrs := []error{} + var allErrs []error var mu sync.Mutex wg := sync.WaitGroup{} err := func() error { @@ -69,7 +69,9 @@ func (s *Server) ServeCodec(ctx context.Context, remote jsonrpc.ReaderWriter) er } }() wg.Wait() - allErrs = append(allErrs, err) + if err != nil { + allErrs = append(allErrs, err) + } if len(allErrs) > 0 { return errors.Join(allErrs...) }