From 7d38d53ae449c6ec06f7b0579f1a189b02222a60 Mon Sep 17 00:00:00 2001
From: Nilesh Trivedi <nilesh@hypertrack.io>
Date: Mon, 20 Aug 2018 19:24:38 +0530
Subject: [PATCH] cmd/puppeth: accept ssh identity in the server string
 (#17407)

* cmd/puppeth: Accept identityfile in the server string with fallback to id_rsa

* cmd/puppeth: code polishes + fix heath check double ports
---
 cmd/puppeth/ssh.go            | 52 ++++++++++++++++++++---------------
 cmd/puppeth/wizard_network.go |  8 +++---
 2 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/cmd/puppeth/ssh.go b/cmd/puppeth/ssh.go
index 158261ce0..c50759606 100644
--- a/cmd/puppeth/ssh.go
+++ b/cmd/puppeth/ssh.go
@@ -45,33 +45,44 @@ type sshClient struct {
 
 // dial establishes an SSH connection to a remote node using the current user and
 // the user's configured private RSA key. If that fails, password authentication
-// is fallen back to. The caller may override the login user via user@server:port.
+// is fallen back to. server can be a string like user:identity@server:port.
 func dial(server string, pubkey []byte) (*sshClient, error) {
-	// Figure out a label for the server and a logger
-	label := server
-	if strings.Contains(label, ":") {
-		label = label[:strings.Index(label, ":")]
-	}
-	login := ""
+	// Figure out username, identity, hostname and port
+	hostname := ""
+	hostport := server
+	username := ""
+	identity := "id_rsa" // default
+
 	if strings.Contains(server, "@") {
-		login = label[:strings.Index(label, "@")]
-		label = label[strings.Index(label, "@")+1:]
-		server = server[strings.Index(server, "@")+1:]
+		prefix := server[:strings.Index(server, "@")]
+		if strings.Contains(prefix, ":") {
+			username = prefix[:strings.Index(prefix, ":")]
+			identity = prefix[strings.Index(prefix, ":")+1:]
+		} else {
+			username = prefix
+		}
+		hostport = server[strings.Index(server, "@")+1:]
 	}
-	logger := log.New("server", label)
+	if strings.Contains(hostport, ":") {
+		hostname = hostport[:strings.Index(hostport, ":")]
+	} else {
+		hostname = hostport
+		hostport += ":22"
+	}
+	logger := log.New("server", server)
 	logger.Debug("Attempting to establish SSH connection")
 
 	user, err := user.Current()
 	if err != nil {
 		return nil, err
 	}
-	if login == "" {
-		login = user.Username
+	if username == "" {
+		username = user.Username
 	}
 	// Configure the supported authentication methods (private key and password)
 	var auths []ssh.AuthMethod
 
-	path := filepath.Join(user.HomeDir, ".ssh", "id_rsa")
+	path := filepath.Join(user.HomeDir, ".ssh", identity)
 	if buf, err := ioutil.ReadFile(path); err != nil {
 		log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
 	} else {
@@ -94,14 +105,14 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
 		}
 	}
 	auths = append(auths, ssh.PasswordCallback(func() (string, error) {
-		fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server)
+		fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
 		blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
 
 		fmt.Println()
 		return string(blob), err
 	}))
 	// Resolve the IP address of the remote server
-	addr, err := net.LookupHost(label)
+	addr, err := net.LookupHost(hostname)
 	if err != nil {
 		return nil, err
 	}
@@ -109,10 +120,7 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
 		return nil, errors.New("no IPs associated with domain")
 	}
 	// Try to dial in to the remote server
-	logger.Trace("Dialing remote SSH server", "user", login)
-	if !strings.Contains(server, ":") {
-		server += ":22"
-	}
+	logger.Trace("Dialing remote SSH server", "user", username)
 	keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
 		// If no public key is known for SSH, ask the user to confirm
 		if pubkey == nil {
@@ -139,13 +147,13 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
 		// We have a mismatch, forbid connecting
 		return errors.New("ssh key mismatch, readd the machine to update")
 	}
-	client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck})
+	client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck})
 	if err != nil {
 		return nil, err
 	}
 	// Connection established, return our utility wrapper
 	c := &sshClient{
-		server:  label,
+		server:  hostname,
 		address: addr[0],
 		pubkey:  pubkey,
 		client:  client,
diff --git a/cmd/puppeth/wizard_network.go b/cmd/puppeth/wizard_network.go
index d780c550b..c0ddcc2a3 100644
--- a/cmd/puppeth/wizard_network.go
+++ b/cmd/puppeth/wizard_network.go
@@ -62,14 +62,14 @@ func (w *wizard) manageServers() {
 	}
 }
 
-// makeServer reads a single line from stdin and interprets it as a hostname to
-// connect to. It tries to establish a new SSH session and also executing some
-// baseline validations.
+// makeServer reads a single line from stdin and interprets it as
+// username:identity@hostname to connect to. It tries to establish a
+// new SSH session and also executing some baseline validations.
 //
 // If connection succeeds, the server is added to the wizards configs!
 func (w *wizard) makeServer() string {
 	fmt.Println()
-	fmt.Println("Please enter remote server's address:")
+	fmt.Println("What is the remote server's address ([username[:identity]@]hostname[:port])?")
 
 	// Read and dial the server to ensure docker is present
 	input := w.readString()
-- 
GitLab