diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go index 47f2879fed7ad856f7bcfdb27b735affea394516..017e06d54caa466f7f8da8a9175b9f5394d633a8 100644 --- a/contrib/codecs/websocket/codec.go +++ b/contrib/codecs/websocket/codec.go @@ -2,28 +2,38 @@ package websocket import ( "context" + "io" "net/http" + "sync" "time" "gfx.cafe/open/websocket" + "github.com/go-faster/jx" + "github.com/goccy/go-json" - "gfx.cafe/open/jrpc/contrib/codecs/rdwr" "gfx.cafe/open/jrpc/pkg/codec" + "gfx.cafe/open/jrpc/pkg/serverutil" ) type Codec struct { - *rdwr.Codec - conn *websocket.Conn + closed chan struct{} + conn *websocket.Conn + + jx *jx.Encoder + wrLock sync.Mutex + + decBuf json.RawMessage + decLock sync.Mutex i codec.PeerInfo } func newWebsocketCodec(ctx context.Context, conn *websocket.Conn, host string, req http.Header) *Codec { conn.SetReadLimit(WsMessageSizeLimit) - netConn := websocket.NetConn(ctx, conn, websocket.MessageText) c := &Codec{ - Codec: rdwr.NewCodec(netConn, netConn), - conn: conn, + closed: make(chan struct{}), + conn: conn, + jx: jx.NewStreamingEncoder(nil, 4096), } c.i.Transport = "ws" // Fill in connection details. @@ -62,13 +72,62 @@ func heartbeat(ctx context.Context, c *websocket.Conn, d time.Duration) { } } +func (c *Codec) decodeSingleMessage(ctx context.Context) (*serverutil.Bundle, error) { + c.decLock.Lock() + defer c.decLock.Unlock() + c.decBuf = c.decBuf[:0] + _, r, err := c.conn.Reader(ctx) + if err != nil { + return nil, err + } + defer io.Copy(io.Discard, r) + err = json.NewDecoder(r).DecodeContext(ctx, &c.decBuf) + if err != nil { + return nil, err + } + return serverutil.ParseBundle(c.decBuf), nil +} + +func (c *Codec) ReadBatch(ctx context.Context) ([]*codec.Message, bool, error) { + ans, err := c.decodeSingleMessage(ctx) + if err != nil { + return nil, false, err + } + return ans.Messages, ans.Batch, nil +} + +func (c *Codec) Send(fn func(e *jx.Encoder) error) error { + c.wrLock.Lock() + defer c.wrLock.Unlock() + + wr, err := c.conn.Writer(context.Background(), websocket.MessageText) + if err != nil { + return err + } + c.jx.ResetWriter(wr) + if err = fn(c.jx); err != nil { + return err + } + if err = c.jx.Close(); err != nil { + return err + } + return wr.Close() +} + func (c *Codec) PeerInfo() codec.PeerInfo { return c.i } +func (c *Codec) Closed() <-chan struct{} { + return c.closed +} + func (c *Codec) Close() error { - if err := c.Codec.Close(); err != nil { - return err + select { + case <-c.closed: + return nil + default: + close(c.closed) } return c.conn.Close(websocket.StatusNormalClosure, "") } @@ -76,3 +135,5 @@ func (c *Codec) Close() error { func (c *Codec) RemoteAddr() string { return c.i.RemoteAddr } + +var _ codec.ReaderWriter = (*Codec)(nil) diff --git a/contrib/codecs/websocket/handler.go b/contrib/codecs/websocket/handler.go index c2f9d27f68c76163eb9bcfd6981539dded084cad..7eaf816d597ad22f47a4eade6a71dd32ecf3f41b 100644 --- a/contrib/codecs/websocket/handler.go +++ b/contrib/codecs/websocket/handler.go @@ -1,7 +1,6 @@ package websocket import ( - "log/slog" "net/http" "gfx.cafe/open/websocket" @@ -26,7 +25,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { c := newWebsocketCodec(r.Context(), conn, "", r.Header) err = s.Server.ServeCodec(r.Context(), c) if err != nil { - slog.Error("codec err", "error", err) + // slog.Error("codec err", "error", err) } } @@ -47,7 +46,7 @@ func WebsocketHandler(s *server.Server, allowedOrigins []string) http.Handler { codec := newWebsocketCodec(r.Context(), conn, r.Host, r.Header) err = s.ServeCodec(r.Context(), codec) if err != nil { - // slog.Error("codec err", "error", err) + // slog.Error("codec err", "error", err) } }) } diff --git a/pkg/clientutil/idreply.go b/pkg/clientutil/idreply.go index b9520d3b50bb5bda9fdab660d15e2e1ef1882662..bfb17d6198ca2eb744a714cd61779dcfecfbd03e 100644 --- a/pkg/clientutil/idreply.go +++ b/pkg/clientutil/idreply.go @@ -31,26 +31,41 @@ func (i *IdReply) NextId() *codec.ID { return codec.NewNumberIDPtr(i.id.Add(1)) } -func (i *IdReply) makeOrTake(id []byte) chan msgOrError { +func (i *IdReply) make(id []byte) <-chan msgOrError { i.mu.Lock() defer i.mu.Unlock() - if val, ok := i.chs[string(id)]; ok { - delete(i.chs, string(id)) - return val - } - o := make(chan msgOrError) - i.chs[string(id)] = o - return o + ch := make(chan msgOrError, 1) + i.chs[string(id)] = ch + return ch +} + +func (i *IdReply) take(id []byte) chan<- msgOrError { + i.mu.Lock() + defer i.mu.Unlock() + ch := i.chs[string(id)] + delete(i.chs, string(id)) + return ch +} + +func (i *IdReply) remove(id []byte) { + i.mu.Lock() + defer i.mu.Unlock() + delete(i.chs, string(id)) } func (i *IdReply) Resolve(id []byte, msg json.RawMessage, err error) { + ch := i.take(id) + if ch == nil { + return + } + if err != nil { - i.makeOrTake(id) <- msgOrError{ + ch <- msgOrError{ err: err, } return } - i.makeOrTake(id) <- msgOrError{ + ch <- msgOrError{ msg: msg, } @@ -58,9 +73,10 @@ func (i *IdReply) Resolve(id []byte, msg json.RawMessage, err error) { func (i *IdReply) Ask(ctx context.Context, id []byte) (json.RawMessage, error) { select { - case resp := <-i.makeOrTake(id): + case resp := <-i.make(id): return resp.msg, resp.err case <-ctx.Done(): + i.remove(id) return nil, ctx.Err() } } diff --git a/pkg/codec/wire.go b/pkg/codec/wire.go index 7a57692feca297ae4ab97bc60d197b93906a35f5..4a9c19cde29b3cd7f22253ddac3fd39370d2b7ef 100644 --- a/pkg/codec/wire.go +++ b/pkg/codec/wire.go @@ -6,7 +6,7 @@ import ( "reflect" "strconv" - json "github.com/goccy/go-json" + "github.com/goccy/go-json" ) // Version represents a JSON-RPC version. @@ -130,7 +130,7 @@ func (id ID) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (id *ID) UnmarshalJSON(data []byte) error { - *id = data + *id = bytes.Clone(data) // now validate if id.IsNull() { return nil diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go index 9adb7da7cfbb2ed171f4734a2391d4cb0818bf6f..ccb01c820b574a62c4dd1884c505e2d80fa2517a 100644 --- a/pkg/jrpctest/suites.go +++ b/pkg/jrpctest/suites.go @@ -220,6 +220,31 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { } wg.Wait() }) + + makeTest("big", func(t *testing.T, server *server.Server, client codec.Conn) { + var ( + wg sync.WaitGroup + nreqs = 2 + ncallers = 10 + ) + wg.Add(ncallers) + // create a bunch of parallel requests with lots of data to see if any buffers are overwritten causing a failure + for i := 0; i < ncallers; i++ { + go func() { + defer wg.Done() + + for j := 0; j < nreqs; j++ { + if err := codec.CallInto(context.Background(), client, nil, "large_largeResp"); err != nil { + t.Error(err) + return + } + } + }() + } + + wg.Wait() + }) + makeTest("", func(t *testing.T, server *server.Server, client codec.Conn) { }) } diff --git a/pkg/server/server.go b/pkg/server/server.go index de6878b24251a38605846c56e100dcaaa0db4e50..e6df14654b39163c08a82d725d10eb7e336168a0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -43,13 +43,19 @@ func (s *Server) ServeCodec(ctx context.Context, remote codec.ReaderWriter) erro // read messages from the stream synchronously incoming, batch, err := remote.ReadBatch(ctx) if err != nil { - errch <- err + select { + case errch <- err: + case <-ctx.Done(): + } return } go func() { err = s.serveBatch(ctx, incoming, batch, remote, responder) if err != nil { - errch <- err + select { + case errch <- err: + case <-ctx.Done(): + } return } }()