From ad983b300b138b2300402187c2e50f9afdc4b2a3 Mon Sep 17 00:00:00 2001
From: lightclient <14004106+lightclient@users.noreply.github.com>
Date: Tue, 27 Apr 2021 03:36:57 -0600
Subject: [PATCH] cmd/puppeth: add support for authentication via ssh agent
 (#22634)

---
 cmd/puppeth/ssh.go | 62 ++++++++++++++++++++++++++++------------------
 1 file changed, 38 insertions(+), 24 deletions(-)

diff --git a/cmd/puppeth/ssh.go b/cmd/puppeth/ssh.go
index da2862db2..039cb6cb4 100644
--- a/cmd/puppeth/ssh.go
+++ b/cmd/puppeth/ssh.go
@@ -30,6 +30,7 @@ import (
 
 	"github.com/ethereum/go-ethereum/log"
 	"golang.org/x/crypto/ssh"
+	"golang.org/x/crypto/ssh/agent"
 	"golang.org/x/crypto/ssh/terminal"
 )
 
@@ -43,6 +44,8 @@ type sshClient struct {
 	logger  log.Logger
 }
 
+const EnvSSHAuthSock = "SSH_AUTH_SOCK"
+
 // dial establishes an SSH connection to a remote node using the current user and
 // the user's configured private RSA key. If that fails, password authentication
 // is fallen back to. server can be a string like user:identity@server:port.
@@ -79,38 +82,49 @@ func dial(server string, pubkey []byte) (*sshClient, error) {
 	if username == "" {
 		username = user.Username
 	}
-	// Configure the supported authentication methods (private key and password)
-	var auths []ssh.AuthMethod
 
-	path := filepath.Join(user.HomeDir, ".ssh", identity)
-	if buf, err := ioutil.ReadFile(path); err != nil {
-		log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
+	// Configure the supported authentication methods (ssh agent, private key and password)
+	var (
+		auths []ssh.AuthMethod
+		conn  net.Conn
+	)
+	if conn, err = net.Dial("unix", os.Getenv(EnvSSHAuthSock)); err != nil {
+		log.Warn("Unable to dial SSH agent, falling back to private keys", "err", err)
 	} else {
-		key, err := ssh.ParsePrivateKey(buf)
-		if err != nil {
-			fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path)
-			blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
-			fmt.Println()
-			if err != nil {
-				log.Warn("Couldn't read password", "err", err)
-			}
-			key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob)
+		client := agent.NewClient(conn)
+		auths = append(auths, ssh.PublicKeysCallback(client.Signers))
+	}
+	if err != nil {
+		path := filepath.Join(user.HomeDir, ".ssh", identity)
+		if buf, err := ioutil.ReadFile(path); err != nil {
+			log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
+		} else {
+			key, err := ssh.ParsePrivateKey(buf)
 			if err != nil {
-				log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err)
+				fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path)
+				blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
+				fmt.Println()
+				if err != nil {
+					log.Warn("Couldn't read password", "err", err)
+				}
+				key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob)
+				if err != nil {
+					log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err)
+				} else {
+					auths = append(auths, ssh.PublicKeys(key))
+				}
 			} else {
 				auths = append(auths, ssh.PublicKeys(key))
 			}
-		} else {
-			auths = append(auths, ssh.PublicKeys(key))
 		}
-	}
-	auths = append(auths, ssh.PasswordCallback(func() (string, error) {
-		fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
-		blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
+		auths = append(auths, ssh.PasswordCallback(func() (string, error) {
+			fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
+			blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
 
-		fmt.Println()
-		return string(blob), err
-	}))
+			fmt.Println()
+			return string(blob), err
+		}))
+	}
 	// Resolve the IP address of the remote server
 	addr, err := net.LookupHost(hostname)
 	if err != nil {
-- 
GitLab