diff --git a/contrib/codecs/http/codec.go b/contrib/codecs/http/codec.go index ae14d5e9f1834412e0854b4e05ca05a107df6d38..3941c554c91e54f032dc877b3ce18644574d0020 100644 --- a/contrib/codecs/http/codec.go +++ b/contrib/codecs/http/codec.go @@ -40,12 +40,6 @@ type httpError struct { err error } -func NewCodec(w http.ResponseWriter, r *http.Request) *Codec { - c := &Codec{} - c.Reset(w, r) - return c -} - func (c *Codec) Reset(w http.ResponseWriter, r *http.Request) { c.wr = bufio.NewWriter(w) if w == nil { diff --git a/contrib/codecs/http/codecs.go b/contrib/codecs/http/codecs.go new file mode 100644 index 0000000000000000000000000000000000000000..c14e4a94ecef040d19698777e0c6e345f5ac0732 --- /dev/null +++ b/contrib/codecs/http/codecs.go @@ -0,0 +1,154 @@ +package http + +import ( + "context" + "encoding/base64" + "errors" + "io" + "net/http" + "net/url" + "strings" + + "gfx.cafe/open/jrpc/pkg/jsonrpc" + "gfx.cafe/open/jrpc/pkg/serverutil" +) + +var _ jsonrpc.ReaderWriter = (*HttpCodec)(nil) + +type HttpCodec struct { + ctx context.Context + cn context.CancelFunc + + r *http.Request + w http.ResponseWriter + i jsonrpc.PeerInfo + + f http.Flusher + + msgs *serverutil.Bundle +} + +func NewCodec(w http.ResponseWriter, r *http.Request) (*HttpCodec, error) { + switch r.Method { + case http.MethodGet: + return NewGetCodec(w, r), nil + case http.MethodPost: + return NewPostCodec(w, r) + default: + http.Error(w, "method not supported", http.StatusMethodNotAllowed) + return nil, errors.New("method not allowed") + } +} + +func NewGetCodec(w http.ResponseWriter, r *http.Request) *HttpCodec { + c := &HttpCodec{ + r: r, + w: w, + i: jsonrpc.PeerInfo{ + Transport: "http", + RemoteAddr: r.RemoteAddr, + HTTP: r.Clone(r.Context()), + }, + } + c.ctx, c.cn = context.WithCancel(r.Context()) + flusher, ok := w.(http.Flusher) + if ok { + c.f = flusher + } + + method_up := r.URL.Query().Get("method") + if method_up == "" { + method_up = strings.TrimPrefix(r.URL.Path, "/") + } + params, _ := url.QueryUnescape(r.URL.Query().Get("params")) + var param []byte + // try to read params as base64 + if pb, err := base64.URLEncoding.DecodeString(params); err == nil { + param = pb + } else { + // otherwise just take them raw + param = []byte(params) + } + id := r.URL.Query().Get("id") + if id == "" { + id = "1" + } + c.msgs = &serverutil.Bundle{ + Messages: []*jsonrpc.Message{{ + ID: jsonrpc.NewId(id), + Method: method_up, + Params: param, + }}, + Batch: false, + } + return c +} + +func NewPostCodec(w http.ResponseWriter, r *http.Request) (*HttpCodec, error) { + c := &HttpCodec{ + r: r, + w: w, + i: jsonrpc.PeerInfo{ + Transport: "http", + RemoteAddr: r.RemoteAddr, + HTTP: r.Clone(r.Context()), + }, + } + c.ctx, c.cn = context.WithCancel(r.Context()) + flusher, ok := w.(http.Flusher) + if ok { + c.f = flusher + } + + data, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + c.msgs = serverutil.ParseBundle(data) + + return c, nil +} + +// gets the peer info +func (c *HttpCodec) PeerInfo() jsonrpc.PeerInfo { + return c.i +} + +func (c *HttpCodec) ReadBatch(ctx context.Context) ([]*jsonrpc.Message, bool, error) { + if c.msgs == nil { + return nil, false, io.EOF + } + c.msgs = nil + return c.msgs.Messages, c.msgs.Batch, nil +} + +// closes the connection +func (c *HttpCodec) Write(p []byte) (n int, err error) { + return c.w.Write(p) +} + +func (c *HttpCodec) Flush() error { + c.w.Write([]byte{'\n'}) + if c.f != nil { + c.f.Flush() + } + return nil +} + +func (c *HttpCodec) Close() error { + if c.f != nil { + c.f.Flush() + } + c.cn() + return nil +} + +// Closed returns a channel which is closed when the connection is closed. +func (c *HttpCodec) Closed() <-chan struct{} { + return c.ctx.Done() +} + +// RemoteAddr returns the peer address of the connection. +func (c *HttpCodec) RemoteAddr() string { + return c.r.RemoteAddr +} diff --git a/contrib/codecs/http/handler.go b/contrib/codecs/http/handler.go index 8ce711e7be3b5eb2aced347d1f0c6af069ce6578..afb0e007b0c19f64338cee28057d6c3418f33ebf 100644 --- a/contrib/codecs/http/handler.go +++ b/contrib/codecs/http/handler.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "sync" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -13,30 +12,22 @@ import ( ) func HttpHandler(s *server.Server) http.Handler { - return h2c.NewHandler(&Server{Server: s}, &http2.Server{}) -} - -type Server struct { - Server *server.Server -} - -var codecPool = sync.Pool{ - New: func() any { - return &Codec{} - }, -} - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if s.Server == nil { - http.Error(w, "no server set", http.StatusInternalServerError) - return - } - c := NewCodec(w, r) - w.Header().Set("content-type", contentType) - err := s.Server.ServeCodec(r.Context(), c) - if err != nil && !errors.Is(err, context.Canceled) { - // slog.Error("codec err", "err", err) - http.Error(w, "Internal Error", http.StatusInternalServerError) - } - <-c.Closed() + return h2c.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if s == nil { + http.Error(w, "no server set", http.StatusInternalServerError) + return + } + c, err := NewCodec(w, r) + if err != nil { + return + } + w.Header().Set("content-type", contentType) + err = s.ServeCodec(r.Context(), c) + if err != nil && !errors.Is(err, context.Canceled) { + // slog.Error("codec err", "err", err) + http.Error(w, "Internal Error", http.StatusInternalServerError) + return + } + <-c.Closed() + }), &http2.Server{}) }