diff --git a/pkg/jsonrpc/message.go b/pkg/jsonrpc/message.go index 67cf08269a2e2337a634f7fad03fe524548fdcb8..13629b3d79855f31bc7c2426e84e67fbe8a5bc56 100644 --- a/pkg/jsonrpc/message.go +++ b/pkg/jsonrpc/message.go @@ -3,31 +3,46 @@ package jsonrpc import ( "encoding/json" "io" + + "golang.org/x/net/context" + "golang.org/x/sync/semaphore" ) // MessageStream is a writer used to write jsonrpc message to a stream type MessageStream struct { - w io.Writer + w io.Writer + mu *semaphore.Weighted } func NewStream(w io.Writer) *MessageStream { return &MessageStream{ - w: w, + w: w, + mu: semaphore.NewWeighted(1), } } -func (m *MessageStream) NewMessage() (*MessageWriter, error) { - _, err := m.w.Write([]byte(`{"jsonrpc":"2.0"`)) +// NewMessage starts a new message and acquires the write lock. +// to free the write lock, you must call *MessageWriter.Close() +// the lock MUST be closed if and only if err != nil +func (m *MessageStream) NewMessage(ctx context.Context) (*MessageWriter, error) { + err := m.mu.Acquire(ctx, 1) + if err != nil { + return nil, err + } + _, err = m.w.Write([]byte(`{"jsonrpc":"2.0"`)) if err != nil { + m.mu.Release(1) return nil, err } return &MessageWriter{ - w: m.w, + w: m.w, + mu: m.mu, }, nil } type MessageWriter struct { - w io.Writer + w io.Writer + mu *semaphore.Weighted } func (m *MessageWriter) Field(name string, value json.RawMessage) error { @@ -51,6 +66,8 @@ func (m *MessageWriter) Result() (io.Writer, error) { return &ResultWriter{w: m.w}, nil } +// close must be called when you are done writing the message. +// it releases the write lock func (m *MessageWriter) Close() error { _, err := m.w.Write([]byte("}")) return err diff --git a/pkg/server/server.go b/pkg/server/server.go index 889769186cd3f3c2ac0d55a77fcd0e857e452b02..ce5c3f9d91032ead3a95f9f5e5312ae27726b7cd 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -293,7 +293,7 @@ type callEnv struct { func (c *callResponder) send(ctx context.Context, env *callEnv) (err error) { w := c.remote - s, err := jsonrpc.NewStream(w).NewMessage() + s, err := jsonrpc.NewStream(w).NewMessage(ctx) if err != nil { return err }