From aca755ab4330e324da57d82ac32bd0912f23d8c3 Mon Sep 17 00:00:00 2001
From: leonardchinonso <36096513+leonardchinonso@users.noreply.github.com>
Date: Sat, 12 Feb 2022 12:40:19 +0100
Subject: [PATCH] Fix/rpcdaemon ws upgrade (#3490)

* Added a method to `cmd/rpcdaemon/cli/config.go` to check header of incoming http request for a ws upgrade request

* Added the testing of the 'GET /' request for ws to the devnet tool

* Fixed lint errors
---
 cmd/devnettest/commands/requests.go          | 18 ++++++++++
 cmd/devnettest/requests/mock_requests.go     | 16 +++++++++
 cmd/devnettest/requests/request_generator.go | 35 ++++++++++++++++++++
 cmd/rpcdaemon/cli/config.go                  |  9 ++++-
 4 files changed, 77 insertions(+), 1 deletion(-)
 create mode 100644 cmd/devnettest/commands/requests.go
 create mode 100644 cmd/devnettest/requests/mock_requests.go

diff --git a/cmd/devnettest/commands/requests.go b/cmd/devnettest/commands/requests.go
new file mode 100644
index 0000000000..a8d3a7f9aa
--- /dev/null
+++ b/cmd/devnettest/commands/requests.go
@@ -0,0 +1,18 @@
+package commands
+
+import (
+	"github.com/ledgerwatch/erigon/cmd/devnettest/requests"
+	"github.com/spf13/cobra"
+)
+
+func init() {
+	rootCmd.AddCommand(mockRequestCmd)
+}
+
+var mockRequestCmd = &cobra.Command{
+	Use:   "mock",
+	Short: "Mocks a request on the devnet",
+	Run: func(cmd *cobra.Command, args []string) {
+		requests.MockGetRequest(reqId)
+	},
+}
diff --git a/cmd/devnettest/requests/mock_requests.go b/cmd/devnettest/requests/mock_requests.go
new file mode 100644
index 0000000000..834d6ab247
--- /dev/null
+++ b/cmd/devnettest/requests/mock_requests.go
@@ -0,0 +1,16 @@
+package requests
+
+import "fmt"
+
+func MockGetRequest(reqId int) {
+	reqGen := initialiseRequestGenerator(reqId)
+
+	res := reqGen.Get()
+
+	if res.Err != nil {
+		fmt.Printf("error: %v\n", res.Err)
+		return
+	}
+
+	fmt.Printf("OK\n")
+}
diff --git a/cmd/devnettest/requests/request_generator.go b/cmd/devnettest/requests/request_generator.go
index 7138c223dc..9075eff2b0 100644
--- a/cmd/devnettest/requests/request_generator.go
+++ b/cmd/devnettest/requests/request_generator.go
@@ -1,7 +1,9 @@
 package requests
 
 import (
+	"errors"
 	"fmt"
+	"io/ioutil"
 	"net/http"
 	"time"
 
@@ -34,6 +36,39 @@ func initialiseRequestGenerator(reqId int) *RequestGenerator {
 	return &reqGen
 }
 
+func (req *RequestGenerator) Get() rpctest.CallResult {
+	start := time.Now()
+	res := rpctest.CallResult{
+		RequestID: req.reqID,
+	}
+
+	resp, err := http.Get(erigonUrl)
+	if err != nil {
+		res.Took = time.Since(start)
+		res.Err = err
+		return res
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != 200 {
+		res.Took = time.Since(start)
+		res.Err = errors.New("bad request")
+		return res
+	}
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		res.Took = time.Since(start)
+		res.Err = err
+		return res
+	}
+
+	res.Response = body
+	res.Took = time.Since(start)
+	res.Err = err
+	return res
+}
+
 func (req *RequestGenerator) Erigon(method, body string, response interface{}) rpctest.CallResult {
 	return req.call(erigonUrl, method, body, response)
 }
diff --git a/cmd/rpcdaemon/cli/config.go b/cmd/rpcdaemon/cli/config.go
index ae3760bdf6..eaa2031a99 100644
--- a/cmd/rpcdaemon/cli/config.go
+++ b/cmd/rpcdaemon/cli/config.go
@@ -9,6 +9,7 @@ import (
 	"net"
 	"net/http"
 	"path"
+	"strings"
 	"time"
 
 	"github.com/ledgerwatch/erigon-lib/gointerfaces"
@@ -498,13 +499,19 @@ func StartRpcServer(ctx context.Context, cfg Flags, rpcAPI []rpc.API) error {
 	return nil
 }
 
+// isWebsocket checks the header of a http request for a websocket upgrade request.
+func isWebsocket(r *http.Request) bool {
+	return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
+		strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
+}
+
 func createHandler(cfg Flags, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler) http.Handler {
 	var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		// adding a healthcheck here
 		if health.ProcessHealthcheckIfNeeded(w, r, apiList) {
 			return
 		}
-		if cfg.WebsocketEnabled && wsHandler != nil && r.Method == "GET" {
+		if cfg.WebsocketEnabled && wsHandler != nil && isWebsocket(r) {
 			wsHandler.ServeHTTP(w, r)
 			return
 		}
-- 
GitLab