diff --git a/codec/http/client.go b/codec/http/client.go index 2f3594c6ad9a905862eec07a18a970543400356d..e89799c84299d91c2639746ce2a5469952f36049 100644 --- a/codec/http/client.go +++ b/codec/http/client.go @@ -17,17 +17,17 @@ package jrpc import ( + "bytes" "context" "encoding/json" "errors" "fmt" - "net/url" - "reflect" + "net/http" "sync/atomic" "time" - jsoniter "github.com/json-iterator/go" - "tuxpa.in/a/zlog/log" + "gfx.cafe/open/jrpc" + "gfx.cafe/open/jrpc/codec" ) var ( @@ -44,550 +44,98 @@ const ( subscribeTimeout = 5 * time.Second // overall timeout eth_subscribe, rpc_modules calls ) -var _ SubscriptionConn = (*Client)(nil) - // Client represents a connection to an RPC server. type Client struct { - isHTTP bool // connection type: http, ws or ipc - - idCounter uint64 - - r Handler - // This function, if non-nil, is called when the connection is lost. - reconnectFunc reconnectFunc - - // writeConn is used for writing to the connection on the caller's goroutine. It should - // only be accessed outside of dispatch, with the write lock held. The write lock is - // taken by sending on reqInit and released by sending on reqSent. - writeConn JsonWriter - - // for dispatch - close chan struct{} - closing chan struct{} // closed when client is quitting - didClose chan struct{} // closed when client quits - reconnected chan ServerCodec // where write/reconnect sends the new connection - readOp chan readOp // read messages - readErr chan error // errors from read - reqInit chan *requestOp // register response IDs, takes write lock - reqSent chan error // signals write completion, releases write lock - reqTimeout chan *requestOp // removes response IDs when call timeout expires -} - -type reconnectFunc func(ctx context.Context) (ServerCodec, error) - -type clientContextKey struct{} - -type clientConn struct { - codec ServerCodec - handler *handler -} - -func (c *Client) newClientConn(conn ServerCodec) *clientConn { - ctx := context.Background() - ctx = context.WithValue(ctx, clientContextKey{}, c) - ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.PeerInfo()) - handler := newHandler(ctx, conn, c.r) - return &clientConn{conn, handler} -} + remote string + c *http.Client -func (cc *clientConn) close(err error, inflightReq *requestOp) { - cc.handler.close(err, inflightReq) - cc.codec.Close() + id atomic.Int64 } -type readOp struct { - msgs []*jsonrpcMessage - batch bool +func Dial(ctx context.Context, client *http.Client, target string) (*Client, error) { + return &Client{remote: target, c: client}, nil } -type requestOp struct { - ids []json.RawMessage - err error - resp chan *jsonrpcMessage // receives up to len(ids) responses - - sub *ClientSubscription -} - -func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { - select { - case <-ctx.Done(): - // Send the timeout to dispatch so it can remove the request IDs. - if !c.isHTTP { - select { - case c.reqTimeout <- op: - case <-c.closing: - } - } - return nil, ctx.Err() - case resp := <-op.resp: - return resp, op.err - } -} - -// Dial creates a new client for the given URL. -// -// The currently supported URL schemes are "http", "https", "ws" and "wss". If rawurl is a -// file name with no URL scheme, a local socket connection is established using UNIX -// domain sockets on supported platforms and named pipes on Windows. If you want to -// configure transport options, use DialHTTP, DialWebsocket or DialIPC instead. -// -// For websocket connections, the origin is set to the local host name. -// -// The client reconnects automatically if the connection is lost. -func Dial(rawurl string) (*Client, error) { - return DialContext(context.Background(), rawurl) -} - -// DialContext creates a new RPC client, just like Dial. -// -// The context is used to cancel or time out the initial connection establishment. It does -// not affect subsequent interactions with the client. -func DialContext(ctx context.Context, rawurl string) (*Client, error) { - u, err := url.Parse(rawurl) - if err != nil { - return nil, err - } - switch u.Scheme { - case "http", "https": - return DialHTTP(rawurl) - case "ws", "wss": - return DialWebsocket(ctx, rawurl, "") - case "tcp": - return DialTCP(ctx, rawurl) - case "stdio": - return DialStdIO(ctx) - case "": - return DialIPC(ctx, rawurl) - default: - return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) - } -} - -// ClientFromContext retrieves the client from the context, if any. This can be used to perform -// 'reverse calls' in a handler method. -func ClientFromContext(ctx context.Context) (*Client, bool) { - client, ok := ctx.Value(clientContextKey{}).(*Client) - return client, ok -} - -func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { - conn, err := connect(initctx) - if err != nil { - return nil, err - } - c := initClient(conn, HandlerFunc(func(w ResponseWriter, r *Request) {})) - c.reconnectFunc = connect - return c, nil -} - -func initClient(conn ServerCodec, r Handler) *Client { - _, isHTTP := conn.(*httpConn) - c := &Client{ - r: r, - isHTTP: isHTTP, - writeConn: conn, - close: make(chan struct{}), - closing: make(chan struct{}), - didClose: make(chan struct{}), - reconnected: make(chan ServerCodec), - readOp: make(chan readOp), - readErr: make(chan error), - reqInit: make(chan *requestOp), - reqSent: make(chan error, 1), - reqTimeout: make(chan *requestOp), - } - if !isHTTP { - go c.dispatch(conn) - } - return c -} - -func (c *Client) nextID() *ID { - id := atomic.AddUint64(&c.idCounter, 1) - return NewNumberIDPtr(int64(id)) -} - -// SupportedModules calls the rpc_modules method, retrieving the list of -// APIs that are available on the server. -func (c *Client) SupportedModules() (map[string]string, error) { - var result map[string]string - ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout) - defer cancel() - err := c.Call(ctx, &result, "rpc_modules") - return result, err -} - -// Close closes the client, aborting any in-flight requests. -func (c *Client) Close() error { - if c.isHTTP { - return nil - } - select { - case c.close <- struct{}{}: - <-c.didClose - case <-c.didClose: - } - return nil -} - -// SetHeader adds a custom HTTP header to the client's requests. -// This method only works for clients using HTTP, it doesn't have -// any effect for clients using another transport. -func (c *Client) SetHeader(key, value string) { - if !c.isHTTP { - return - } - conn := c.writeConn.(*httpConn) - conn.mu.Lock() - conn.headers.Set(key, value) - conn.mu.Unlock() -} - -func (c *Client) call(ctx context.Context, result any, msg *jsonrpcMessage) error { - var err error - op := &requestOp{ids: []json.RawMessage{msg.ID.RawMessage()}, resp: make(chan *jsonrpcMessage, 1)} - - if c.isHTTP { - err = c.sendHTTP(ctx, op, msg) - } else { - err = c.send(ctx, op, msg) - } +func (c *Client) Do(ctx context.Context, result any, method string, params any) error { + req := jrpc.NewRequestInt(ctx, int(c.id.Add(1)), method, params) + dat, err := req.MarshalJSON() if err != nil { return err } - // dispatch has accepted the request and will close the channel when it quits. - resp, err := op.wait(ctx, c) + resp, err := c.c.Post(c.remote, "application/json", bytes.NewBuffer(dat)) if err != nil { return err } - switch { - case resp.Error != nil: - return resp.Error - case len(resp.Result) == 0: - return ErrNoResult - case result == nil: - return nil - default: - return json.Unmarshal(resp.Result, &result) + defer resp.Body.Close() + if result != nil { + json.NewDecoder(resp.Body).Decode(&result) } + return nil } -// Do performs a JSON-RPC call with the given arguments and unmarshals into -// result if no error occurred. -// -// The result must be a pointer so that package json can unmarshal into it. You -// can also pass nil, in which case the result is ignored. -func (c *Client) Do(ctx context.Context, result any, method string, params any) error { - if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr { - return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result) - } - msg, err := c.newMessageP(method, params) +func (c *Client) Notify(ctx context.Context, result any, method string, params any) error { + req := jrpc.NewRequestInt(ctx, int(c.id.Add(1)), method, params) + dat, err := req.MarshalJSON() if err != nil { return err } - if ctx == nil { - ctx = context.TODO() - } - return c.call(ctx, result, msg) -} - -// Call calls Do, except accepts variadic parameters -func (c *Client) Call(ctx context.Context, result any, method string, args ...any) error { - return c.Do(ctx, result, method, args) -} - -// BatchCall sends all given requests as a single batch and waits for the server -// to return a response for all of them. -// -// In contrast to Call, BatchCall only returns I/O errors. Any error specific to -// a request is reported through the Error field of the corresponding BatchElem. -// -// Note that batch calls may not be executed atomically on the server side. -func (c *Client) BatchCall(ctx context.Context, b ...BatchElem) error { - var ( - msgs = make([]*jsonrpcMessage, len(b)) - byID = make(map[string]int, len(b)) - ) - - if ctx == nil { - ctx = context.TODO() - } - op := &requestOp{ - ids: make([]json.RawMessage, len(b)), - resp: make(chan *jsonrpcMessage, len(b)), - } - for i, elem := range b { - msg, err := c.newMessageP(elem.Method, elem.Args) - if err != nil { - return err - } - msgs[i] = msg - op.ids[i] = msg.ID.RawMessage() - byID[string(msg.ID.RawMessage())] = i - } - - var err error - if c.isHTTP { - err = c.sendBatchHTTP(ctx, op, msgs) - } else { - err = c.send(ctx, op, msgs) - } - - // Wait for all responses to come back. - for n := 0; n < len(b) && err == nil; n++ { - var resp *jsonrpcMessage - resp, err = op.wait(ctx, c) - if err != nil { - break - } - // Find the element corresponding to this response. - // The element is guaranteed to be present because dispatch - // only sends valid IDs to our channel. - elem := &b[byID[string(resp.ID.RawMessage())]] - if resp.Error != nil { - elem.Error = resp.Error - continue - } - if len(resp.Result) == 0 { - elem.Error = ErrNoResult - continue - } - elem.Error = json.Unmarshal(resp.Result, elem.Result) - } - - return err -} - -func (c *Client) Notify(ctx context.Context, method string, args ...any) error { - op := new(requestOp) - msg, err := c.newMessageP(method, args) + _, err = c.c.Post(c.remote, "application/json", bytes.NewBuffer(dat)) if err != nil { return err } - if ctx == nil { - ctx = context.TODO() - } - msg.ID = nil - if c.isHTTP { - return c.sendHTTP(ctx, op, msg) - } - return c.send(ctx, op, msg) -} - -func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { - // Check type of channel first. - chanVal := reflect.ValueOf(channel) - if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { - panic("first argument to Subscribe must be a writable channel") - } - if chanVal.IsNil() { - panic("channel given to Subscribe must not be nil") - } - if c.isHTTP { - return nil, ErrNotificationsUnsupported - } - msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...) - if err != nil { - return nil, err - } - op := &requestOp{ - ids: []json.RawMessage{msg.ID.RawMessage()}, - resp: make(chan *jsonrpcMessage), - sub: newClientSubscription(c, namespace, chanVal), - } - - // Send the subscription request. - // The arrival and validity of the response is signaled on sub.quit. - if err := c.send(ctx, op, msg); err != nil { - return nil, err - } - if _, err := op.wait(ctx, c); err != nil { - return nil, err - } - return op.sub, nil + return err } -func (c *Client) newMessage(method string, paramsIn ...any) (*jsonrpcMessage, error) { - return c.newMessageP(method, paramsIn) -} -func (c *Client) newMessageP(method string, paramIn any) (*jsonrpcMessage, error) { - msg := &jsonrpcMessage{ID: c.nextID(), Method: method} - if paramIn != nil { // prevent sending "params":null - if cast, ok := paramIn.(json.RawMessage); ok { - msg.Params = cast +func (c *Client) BatchCall(ctx context.Context, b ...*jrpc.BatchElem) error { + reqs := make([]*jrpc.Request, len(b)) + ids := make([]int, 0, len(b)) + for _, v := range b { + if v.IsNotification { + reqs = append(reqs, jrpc.NewRequest(ctx, "", v.Method, v.Params)) } else { - var err error - if msg.Params, err = jsoniter.Marshal(paramIn); err != nil { - return nil, err - } + id := int(c.id.Add(1)) + ids = append(ids, id) + reqs = append(reqs, jrpc.NewRequestInt(ctx, id, v.Method, v.Params)) } } - return msg, nil -} - -// send registers op with the dispatch loop, then sends msg on the connection. -// if sending fails, op is deregistered. -func (c *Client) send(ctx context.Context, op *requestOp, msg any) error { - select { - case c.reqInit <- op: - err := c.write(ctx, msg, false) - c.reqSent <- err + dat, err := json.Marshal(b) + if err != nil { return err - case <-ctx.Done(): - // This can happen if the client is overloaded or unable to keep up with - // subscription notifications. - return ctx.Err() - case <-c.closing: - return ErrClientQuit } -} - -func (c *Client) write(ctx context.Context, msg any, retry bool) error { - if c.writeConn == nil { - // The previous write failed. Try to establish a new connection. - // time.Sleep(500 * time.Millisecond) - err := c.reconnect(ctx) - if err != nil { - return err - } - } - err := c.writeConn.WriteJSON(ctx, msg) + resp, err := c.c.Post(c.remote, "application/json", bytes.NewBuffer(dat)) if err != nil { - c.writeConn = nil - if !retry { - return c.write(ctx, msg, true) - } - } - return err -} - -func (c *Client) reconnect(ctx context.Context) error { - if c.reconnectFunc == nil { - return errDead + return err } + defer resp.Body.Close() - if _, ok := ctx.Deadline(); !ok { - var cancel func() - ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout) - defer cancel() - } - newconn, err := c.reconnectFunc(ctx) + msgs := []*codec.Message{} + err = json.NewDecoder(resp.Body).Decode(&msgs) if err != nil { - log.Trace().Err(err).Msg("RPC client reconnect failed") return err } - select { - case c.reconnected <- newconn: - c.writeConn = newconn - return nil - case <-c.didClose: - newconn.Close() - return ErrClientQuit + answers := map[int]*codec.Message{} + for _, v := range msgs { + answers[v.ID.Number()] = v } -} -// dispatch is the main loop of the client. -// It sends read messages to waiting calls to Call and BatchCall -// and subscription notifications to registered subscriptions. -func (c *Client) dispatch(codec ServerCodec) { - var ( - lastOp *requestOp // tracks last send operation - reqInitLock = c.reqInit // nil while the send lock is held - conn = c.newClientConn(codec) - reading = true - ) - defer func() { - close(c.closing) - if reading { - conn.close(ErrClientQuit, nil) - c.drainRead() + for i := range ids { + idx := i + ans, ok := answers[i] + if !ok { + b[idx].Error = fmt.Errorf("No response found") + continue } - close(c.didClose) - }() - - // Spawn the initial read loop. - go c.read(codec) - for { - select { - case <-c.close: - return - // Read path: - case op := <-c.readOp: - if op.batch { - conn.handler.handleBatch(op.msgs) - } else { - conn.handler.handleMsg(op.msgs[0]) - } - - case err := <-c.readErr: - conn.handler.log.Debug().Err(err).Msg("RPC connection read error") - conn.close(err, lastOp) - reading = false - - // Reconnect: - case newcodec := <-c.reconnected: - log.Debug().Bool("reading", reading).Str("conn", newcodec.RemoteAddr()).Msg("RPC client reconnected") - if reading { - // Wait for the previous read loop to exit. This is a rare case which - // happens if this loop isn't notified in time after the connection breaks. - // In those cases the caller will notice first and reconnect. Closing the - // handler terminates all waiting requests (closing op.resp) except for - // lastOp, which will be transferred to the new handler. - conn.close(errClientReconnected, lastOp) - c.drainRead() - } - go c.read(newcodec) - reading = true - conn = c.newClientConn(newcodec) - // Re-register the in-flight request on the new handler - // because that's where it will be sent. - conn.handler.addRequestOp(lastOp) - - // Send path: - case op := <-reqInitLock: - // Stop listening for further requests until the current one has been sent. - reqInitLock = nil - lastOp = op - conn.handler.addRequestOp(op) - - case err := <-c.reqSent: - if err != nil { - // Remove response handlers for the last send. When the read loop - // goes down, it will signal all other current operations. - conn.handler.removeRequestOp(lastOp) - } - // Let the next request in. - reqInitLock = c.reqInit - lastOp = nil - - case op := <-c.reqTimeout: - conn.handler.removeRequestOp(op) + if b[idx].Result == nil { + continue } - } -} - -// drainRead drops read messages until an error occurs. -func (c *Client) drainRead() { - for { - select { - case <-c.readOp: - case <-c.readErr: - return + err = json.Unmarshal(ans.Result, b[idx].Result) + if err != nil { + b[idx].Error = err } } + return nil } -// read decodes RPC messages from a codec, feeding them into dispatch. -func (c *Client) read(codec ServerCodec) { - for { - msgs, batch, err := codec.ReadBatch() - if _, ok := err.(*json.SyntaxError); ok { - codec.WriteJSON(context.Background(), errorMessage(&parseError{err.Error()})) - } - if err != nil { - c.readErr <- err - return - } - c.readOp <- readOp{msgs, batch} - } +func (c *Client) Close() error { + return nil } diff --git a/codec/http/http.go b/codec/http/http.go index acc542079d24144d3317f3ab53e77f5a9e5add57..a89a967b982fac3b5e54c492ebc10a084bb87606 100644 --- a/codec/http/http.go +++ b/codec/http/http.go @@ -17,20 +17,19 @@ package jrpc import ( + "bytes" "context" "encoding/base64" - "errors" + "encoding/json" "fmt" "io" - "mime" "net/http" "net/url" "time" + "gfx.cafe/open/jrpc" "gfx.cafe/open/jrpc/codec" "gfx.cafe/util/go/bufpool" - - json "github.com/goccy/go-json" ) const ( @@ -79,102 +78,27 @@ var DefaultHTTPTimeouts = HTTPTimeouts{ } // httpServerConn turns a HTTP connection into a Conn. -type httpServerConn struct { - io.Reader - io.Writer - - jc codec.ReaderWriter - +type requestCodec struct { r *http.Request w http.ResponseWriter - pi codec.PeerInfo -} - -func newHTTPServerConn(r *http.Request, w http.ResponseWriter, pi codec.PeerInfo) codec.ReaderWriter { - c := &httpServerConn{Writer: w, r: r, pi: pi} - // if the request is a GET request, and the body is empty, we turn the request into fake json rpc request, see below - // https://www.jsonrpc.org/historical/json-rpc-over-http.html#encoded-parameters - // we however allow for non base64 encoded parameters to be passed - if r.Method == http.MethodGet { - // default id 1 - id := `1` - id_up := r.URL.Query().Get("id") - if id_up != "" { - id = id_up - } - method_up := r.URL.Query().Get("method") - params, _ := url.QueryUnescape(r.URL.Query().Get("params")) - param := []byte(params) - if pb, err := base64.URLEncoding.DecodeString(params); err == nil { - param = pb - } - buf := bufpool.GetStd() - json.NewEncoder(buf).Encode(jsonrpcMessage{ - ID: NewStringIDPtr(id), - Method: method_up, - Params: param, - }) - c.Reader = buf - } else { - // it's a post request or whatever, so just process it like normal - c.Reader = io.LimitReader(r.Body, maxRequestContentLength) - } - c.jc = NewCodec(c) - return c -} - -func (c *httpServerConn) PeerInfo() PeerInfo { - return c.pi -} - -func (c *httpServerConn) ReadBatch() (messages []*jsonrpcMessage, batch bool, err error) { - return c.jc.ReadBatch() -} - -func (c *httpServerConn) WriteJSON(ctx context.Context, v any) error { - return c.jc.WriteJSON(ctx, v) -} - -func (c *httpServerConn) Close() error { - return nil -} - -// Closed returns a channel which will be closed when Close is called -func (c *httpServerConn) Closed() <-chan any { - return c.jc.Closed() -} + ctx context.Context + cn func() -// RemoteAddr returns the peer address of the underlying connection. -func (t *httpServerConn) RemoteAddr() string { - return t.PeerInfo().RemoteAddr + requestBuffer *bytes.Buffer + pi codec.PeerInfo } -// SetWriteDeadline does nothing and always returns nil. -func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil } - -// ServeHTTP serves JSON-RPC requests over HTTP. -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Permit dumb empty requests for remote health-checks (AWS) - if r.Method == http.MethodGet && r.ContentLength == 0 && r.URL.RawQuery == "" { - w.WriteHeader(http.StatusOK) - return - } - if code, err := validateRequest(r); err != nil { - http.Error(w, err.Error(), code) - return - } - +func NewRequestCodec(r *http.Request, w http.ResponseWriter) *requestCodec { // Create request-scoped context. - connInfo := PeerInfo{ + connInfo := codec.PeerInfo{ Transport: "http", RemoteAddr: r.RemoteAddr, - HTTP: HttpInfo{ - Version: r.Proto, - UserAgent: r.UserAgent(), - Host: r.Host, - Headers: r.Header.Clone(), - WriteHeaders: w.Header(), + HTTP: codec.HttpInfo{ + Version: r.Proto, + UserAgent: r.UserAgent(), + Host: r.Host, + Headers: r.Header.Clone(), }, } connInfo.HTTP.Version = r.Proto @@ -191,44 +115,80 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // the headers used connInfo.HTTP.Headers = r.Header + buf := bufpool.GetStd() - ctx := r.Context() - ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo) + ctx, cn := context.WithCancel(r.Context()) - // All checks passed, create a codec that reads directly from the request body - // until EOF, writes the response to w, and orders the server to process a - // single request. - w.Header().Set("content-type", contentType) + return &requestCodec{ + ctx: ctx, + cn: cn, + r: r, + w: w, + pi: connInfo, + requestBuffer: buf, + } + +} - codec := newHTTPServerConn(r, w, connInfo) - defer codec.Close() - s.serveSingleRequest(ctx, codec) +// gets the peer info +func (r *requestCodec) PeerInfo() codec.PeerInfo { + return r.pi } -// validateRequest returns a non-zero response code and error message if the -// request is invalid. -func validateRequest(r *http.Request) (int, error) { - if r.Method == http.MethodPut || r.Method == http.MethodDelete { - return http.StatusMethodNotAllowed, errors.New("method not allowed") +// json.RawMessage can be an array of requests. if it is, then it is a batch request +func (r *requestCodec) ReadBatch(ctx context.Context) (msgs json.RawMessage, err error) { + if r.r.Method == http.MethodGet { + return r.readBatchGet(ctx) } - if r.ContentLength > maxRequestContentLength { - err := fmt.Errorf("content length too large (%d>%d)", r.ContentLength, maxRequestContentLength) - return http.StatusRequestEntityTooLarge, err + if r.r.Method == http.MethodPost { + return r.readBatch(ctx) } - // Allow OPTIONS (regardless of content-type) - if r.Method == http.MethodOptions { - return 0, nil + return nil, fmt.Errorf("invalid request") +} + +func (r *requestCodec) readBatchGet(ctx context.Context) (msgs json.RawMessage, err error) { + method_up := r.r.URL.Query().Get("method") + params, _ := url.QueryUnescape(r.r.URL.Query().Get("params")) + param := []byte(params) + if pb, err := base64.URLEncoding.DecodeString(params); err == nil { + param = pb + } + req := jrpc.NewRequestInt(ctx, 1, method_up, json.RawMessage(param)) + return req.MarshalJSON() +} + +func (r *requestCodec) readBatch(ctx context.Context) (msgs json.RawMessage, err error) { + rd := io.LimitReader(r.r.Body, maxRequestContentLength) + _, err = io.Copy(r.requestBuffer, rd) + if err != nil { + return nil, err } - // Check content-type - if mt, _, err := mime.ParseMediaType(r.Header.Get("content-type")); err == nil { - for _, accepted := range acceptedContentTypes { - if accepted == mt { - return 0, nil - } - } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-r.ctx.Done(): + return nil, r.ctx.Err() } - // Invalid content-type ignored for now - return 0, nil - //err := fmt.Errorf("invalid content type, only %s is supported", contentType) - //return http.StatusUnsupportedMediaType, err + return json.RawMessage(r.requestBuffer.Bytes()), nil +} + +// closes the connection +func (r *requestCodec) Close() error { + r.cn() + bufpool.PutStd(r.requestBuffer) + return nil +} + +func (r *requestCodec) Write(p []byte) (n int, err error) { + return r.w.Write(p) +} + +// Closed returns a channel which is closed when the connection is closed. +func (r *requestCodec) Closed() <-chan struct{} { + return r.r.Context().Done() +} + +// RemoteAddr returns the peer address of the connection. +func (r *requestCodec) RemoteAddr() string { + return r.pi.RemoteAddr } diff --git a/codec/inproc/inproc.go b/codec/inproc/inproc.go index 0e5665a8156eb47090ead6cf7e641e2de433ef94..6522fe65b825de0956c005f5435255bfdabab16d 100644 --- a/codec/inproc/inproc.go +++ b/codec/inproc/inproc.go @@ -10,7 +10,8 @@ import ( ) type Codec struct { - done chan any + ctx context.Context + cn func() rd io.Reader wr io.Writer @@ -19,8 +20,10 @@ type Codec struct { func NewCodec() *Codec { rd, wr := io.Pipe() + ctx, cn := context.WithCancel(context.TODO()) return &Codec{ - done: make(chan interface{}), + ctx: ctx, + cn: cn, rd: bufio.NewReader(rd), wr: wr, msgs: make(chan json.RawMessage, 8), @@ -43,12 +46,14 @@ func (c *Codec) ReadBatch(ctx context.Context) (msgs json.RawMessage, err error) return ans, nil case <-ctx.Done(): return nil, ctx.Err() + case <-c.ctx.Done(): + return nil, c.ctx.Err() } } // closes the connection func (c *Codec) Close() error { - close(c.done) + c.cn() return nil } @@ -57,8 +62,8 @@ func (c *Codec) Write(p []byte) (n int, err error) { } // Closed returns a channel which is closed when the connection is closed. -func (c *Codec) Closed() <-chan any { - return c.done +func (c *Codec) Closed() <-chan struct{} { + return c.ctx.Done() } // RemoteAddr returns the peer address of the connection. diff --git a/codec/peer.go b/codec/peer.go index 2e9389a3618d8f92b0ea22d2984d1979139209a9..9dba5dfa603200ae4ed20baa5b211d328c99001d 100644 --- a/codec/peer.go +++ b/codec/peer.go @@ -2,14 +2,8 @@ package codec import "net/http" -// PeerInfo contains information about the remote end of the network connection. -// -// This is available within RPC method handlers through the context. Call -// PeerInfoFromContext to get information about the client connection related to -// the current method call. type PeerInfo struct { // Transport is name of the protocol used by the client. - // This can be "http", "ws" or "ipc". Transport string // Address of client. This will usually contain the IP address and port. diff --git a/codec/transport.go b/codec/transport.go index 12a26e0545b01cedfa563e6155889dfea31a8e10..e683473aa6c9b71fbb7a156ba07de8613c08280d 100644 --- a/codec/transport.go +++ b/codec/transport.go @@ -28,7 +28,7 @@ type Writer interface { // write json blob to stream io.Writer // Closed returns a channel which is closed when the connection is closed. - Closed() <-chan any + Closed() <-chan struct{} // RemoteAddr returns the peer address of the connection. RemoteAddr() string } diff --git a/codec/wire.go b/codec/wire.go index fae9404dba1d2d8f0d0e77636adba7e5c9abf92f..280935e4e4a58b7397443b2c6c1176dc35afed2d 100644 --- a/codec/wire.go +++ b/codec/wire.go @@ -1,6 +1,7 @@ package codec import ( + "bytes" "fmt" "strconv" @@ -40,16 +41,15 @@ func (Version) UnmarshalJSON(data []byte) error { // ID is a Request identifier. // -// Only one of either the Name or Number members will be set, using the -// number form if the Name is the empty string. // alternatively, ID can be null -type ID struct { - name string - number int64 +type ID json.RawMessage - null bool - - empty bool +func (i *ID) Format(f fmt.State, verb rune) { + if i == nil { + f.Write(Null) + return + } + f.Write(*i) } // compile time check whether the ID implements a fmt.Formatter, json.Marshaler and json.Unmarshaler interfaces. @@ -68,66 +68,47 @@ func NewStringID(v string) ID { return *NewStringIDPtr(v) } // NewStringID returns a new string request ID. func NewNullID() ID { return *NewNullIDPtr() } -func NewNumberIDPtr(v int64) *ID { return &ID{number: v} } +func NewNumberIDPtr(v int64) *ID { + o := ID(strconv.Itoa(int(v))) + return &o +} func NewStringIDPtr(v string) *ID { if v == "" { return nil } - return &ID{name: v} + o := ID(`"` + v + `"`) + return &o +} +func NewNullIDPtr() *ID { + o := ID("null") + return &o } -func NewNullIDPtr() *ID { return &ID{null: true} } func (id *ID) Number() int { if id == nil { return 0 } - if id.number == 0 { - ans, _ := strconv.Atoi(id.name) - return ans - } - return int(id.number) + ans, _ := strconv.Atoi(string(bytes.Trim(*id, `"'`))) + return ans } -// Format writes the ID to the formatter. -// -// If the rune is q the representation is non ambiguous, -// string forms are quoted, number forms are preceded by a #. -func (id *ID) Format(f fmt.State, r rune) { - numF, strF := `%d`, `%s` - if r == 'q' { - numF, strF = `#%d`, `%q` - } - - id.null = false - switch { - case id.name != "": - fmt.Fprintf(f, strF, id.name) - default: - fmt.Fprintf(f, numF, id.number) - } -} func (id *ID) IsNull() bool { if id == nil { - return true + return false } - return id.null + return len(*id) == 4 && + (*id)[0] == 'n' && + (*id)[1] == 'u' && + (*id)[2] == 'l' && + (*id)[3] == 'l' } // get the raw message func (id *ID) RawMessage() json.RawMessage { - if id.empty { - return nil - } if id == nil { return Null } - if id.null { - return Null - } - if id.name != "" { - return json.RawMessage(`"` + id.name + `"`) - } - return strconv.AppendInt(make([]byte, 0, 8), id.number, 10) + return json.RawMessage(*id) } // MarshalJSON implements json.Marshaler. @@ -137,13 +118,9 @@ func (id *ID) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. func (id *ID) UnmarshalJSON(data []byte) error { - *id = ID{} - if err := json.Unmarshal(data, &id.number); err == nil { - return nil - } - if err := json.Unmarshal(data, &id.name); err == nil { + if len(data) == 0 { return nil } - id.null = true + *id = data return nil } diff --git a/conn.go b/conn.go index 3ad298d196f0141b62aff6b9c8accd8bc7cf52fb..bb5e90ba1e4038695aceeb35a7a5bac436c28ba5 100644 --- a/conn.go +++ b/conn.go @@ -4,8 +4,8 @@ import "context" type Conn interface { Do(ctx context.Context, result any, method string, params any) error + Notify(ctx context.Context, method string, params any) error BatchCall(ctx context.Context, b ...*BatchElem) error - SetHeader(key, value string) Close() error } @@ -18,6 +18,9 @@ type StreamingConn interface { type BatchElem struct { Method string Params any + + IsNotification bool + // The result is unmarshaled into this field. Result must be set to a // non-nil pointer value of the desired type, otherwise the response will be // discarded. diff --git a/go.mod b/go.mod index cd02b566fd6e793342d37c140f7252c70620cb7f..bec5810595745c62e5dbe664fc0aa552b69c5ced 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( gfx.cafe/util/go/bufpool v0.0.0-20230121041905-80dafb1e973e + gfx.cafe/util/go/bytepool v0.0.0-20230502013805-237fcc25d586 gfx.cafe/util/go/frand v0.0.0-20230121041905-80dafb1e973e github.com/alecthomas/kong v0.7.1 github.com/davecgh/go-spew v1.1.1 diff --git a/go.sum b/go.sum index 6ecb1c9ce4c3715bec4182865759edb4caceb575..58da77f7ada26a6ca6d7380fb796676974aa7a59 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= gfx.cafe/util/go/bufpool v0.0.0-20230121041905-80dafb1e973e h1:cx35whzZb3wcLhmOUOiqz0N4f6o9ZWVnRe386rW9R5c= gfx.cafe/util/go/bufpool v0.0.0-20230121041905-80dafb1e973e/go.mod h1:+DiyiCOBGS9O9Ce4ewHQO3Y59h66WSWAbgZZ2O2AYYw= +gfx.cafe/util/go/bytepool v0.0.0-20230502013805-237fcc25d586 h1:M+l4yPLky17DgDUoAof+btd98EWGPYwWjOKPStPuhvU= +gfx.cafe/util/go/bytepool v0.0.0-20230502013805-237fcc25d586/go.mod h1:DEhDlgHR0UUZmqiF4bSm4Big/fV0bTUcOSim1WeXVvs= gfx.cafe/util/go/frand v0.0.0-20230121041905-80dafb1e973e h1:A62zlsu3HkEAVRIb+cCpRIpSTmd047+ABV1KC2RsI2U= gfx.cafe/util/go/frand v0.0.0-20230121041905-80dafb1e973e/go.mod h1:LNHxMJl0WnIr5+OChYxlVopxk+j7qxZv0XvWCzB6uGE= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= diff --git a/request.go b/request.go index e72a1ea31ee8eefeafe714e3594563b73cff722b..ea8fe18e5e5f51018cd9782bf3e8a016e3a9c828 100644 --- a/request.go +++ b/request.go @@ -55,12 +55,6 @@ func (r *Request) makeError(err error) *codec.Message { return m.ErrorResponse(err) } -func (r *Request) errorResponse(err error) *Response { - mw := NewReaderResponseWriterMsg(r) - mw.Send(nil, err) - return mw.Response() -} - func (r *Request) isNotification() bool { return r.ID == nil && len(r.Method) > 0 } diff --git a/server.go b/server.go index 456086cc974de572aba3207d29c869b952ae801b..1a23cba4699041f56873ec1cd6422868e8bd5567 100644 --- a/server.go +++ b/server.go @@ -41,9 +41,10 @@ func NewServer(r Handler) *Server { func (s *Server) printError(remote codec.ReaderWriter, err error) { if err != nil { - if s.Tracing.ErrorLogger != nil { - s.Tracing.ErrorLogger(remote, err) - } + return + } + if s.Tracing.ErrorLogger != nil { + s.Tracing.ErrorLogger(remote, err) } } @@ -82,6 +83,13 @@ func (s *Server) ServeCodec(pctx context.Context, remote codec.ReaderWriter) { } }() + go func() { + select { + case <-ctx.Done(): + remote.Close() + } + }() + for { msgs, err := remote.ReadBatch(ctx) if err != nil { @@ -100,7 +108,6 @@ func (s *Server) ServeCodec(pctx context.Context, remote codec.ReaderWriter) { } env.responses = append(env.responses, rw) } - wg := sync.WaitGroup{} wg.Add(len(msg)) for _, vv := range env.responses {