From 375b34f1ea96e87f562cc9bea12fc7a16e66acc9 Mon Sep 17 00:00:00 2001 From: rinetd Date: Tue, 6 Nov 2018 15:37:35 +0800 Subject: [PATCH] fix some error --- README.md | 2 + auth.go | 76 ++++++++++++++++++--------- client.go | 138 +++++++++++++++++++++++++++++++------------------ client_test.go | 52 +++++++++++++++++++ config.go | 12 +++-- ssh.go | 49 ++++++++++++++++++ 6 files changed, 250 insertions(+), 79 deletions(-) create mode 100644 client_test.go create mode 100644 ssh.go diff --git a/README.md b/README.md index 46d8f43..92b16e7 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ // rsync -av src dst ./src/* --> /root/dst/src/* // rsync -av src dst/ ./src/* --> /root/dst/src/* +## Install +`go get github.com/pytool/ssh` ## Example ### 在远程执行ssh命令 diff --git a/auth.go b/auth.go index 6a353f9..c78940d 100644 --- a/auth.go +++ b/auth.go @@ -14,6 +14,23 @@ import ( "golang.org/x/crypto/ssh/terminal" ) +//HasAgent reports whether the SSH agent is available +func HasAgent() bool { + authsock, ok := os.LookupEnv("SSH_AUTH_SOCK") + if !ok { + return false + } + if dirent, err := os.Stat(authsock); err != nil { + if os.IsNotExist(err) { + return false + } + if dirent.Mode()&os.ModeSocket == 0 { + return false + } + } + return true +} + // An implementation of ssh.KeyboardInteractiveChallenge that simply sends // back the password for all questions. The questions are logged. func passwordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallenge { @@ -48,7 +65,7 @@ func WithPassword(password string) (ssh.AuthMethod, error) { // WithAgent use already authed user func WithAgent() (ssh.AuthMethod, error) { sock := os.Getenv("SSH_AUTH_SOCK") - if sock != "" { + if sock == "" { // fmt.Println(errors.New("Agent Disabled")) return nil, errors.New("Agent Disabled") } @@ -72,23 +89,28 @@ func WithAgent() (ssh.AuthMethod, error) { // return nil } -// WithPrivateKeys 设置多个 ~/.ssh/id_rsa -func WithPrivateKeys(keyFiles []string, password string) (ssh.AuthMethod, error) { +// WithPrivateKeys 设置多个 ~/.ssh/id_rsa ,如果加密用passphrase尝试 +func WithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, error) { var signers []ssh.Signer for _, key := range keyFiles { - buffer, err := ioutil.ReadFile(key) + pemBytes, err := ioutil.ReadFile(key) if err != nil { println(err.Error()) // return } - signer, err := ssh.ParsePrivateKeyWithPassphrase([]byte(buffer), []byte(password)) + signer, err := ssh.ParsePrivateKey([]byte(pemBytes)) if err != nil { - println(err.Error()) - } else { - signers = append(signers, signer) + if strings.Contains(err.Error(), "cannot decode encrypted private keys") { + if signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase)); err != nil { + continue + } + } + // println(err.Error()) } + signers = append(signers, signer) + } if signers == nil { return nil, errors.New("WithPrivateKeys: no keyfiles input") @@ -97,7 +119,7 @@ func WithPrivateKeys(keyFiles []string, password string) (ssh.AuthMethod, error) } // WithPrivateKey 自动监测是否带有密码 -func WithPrivateKey(keyfile string, password string) (ssh.AuthMethod, error) { +func WithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) { pemBytes, err := ioutil.ReadFile(keyfile) if err != nil { println(err.Error()) @@ -108,12 +130,15 @@ func WithPrivateKey(keyfile string, password string) (ssh.AuthMethod, error) { signer, err = ssh.ParsePrivateKey(pemBytes) if err != nil { if strings.Contains(err.Error(), "cannot decode encrypted private keys") { - if signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password)); err == nil { + signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase)) + if err == nil { return ssh.PublicKeys(signer), nil } } + return nil, err } - return nil, err + return ssh.PublicKeys(signer), nil + } // WithPrivateKeyString 直接通过字符串 @@ -134,27 +159,30 @@ func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) { // WithPrivateKeyTerminal 通过终端读取带密码的 PublicKey func WithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) { - // fmt.Fprintf(os.Stderr, "This SSH key is encrypted. Please enter passphrase for key '%s':", priv.path) - passphrase, err := terminal.ReadPassword(int(syscall.Stdin)) - if err != nil { - println(err.Error()) - return nil, err - } - - fmt.Fprintln(os.Stderr) - pemBytes, err := ioutil.ReadFile(keyfile) if err != nil { println(err.Error()) return nil, err } - signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, passphrase) - if err != nil { - fmt.Println(err) + var signer ssh.Signer + signer, err = ssh.ParsePrivateKey(pemBytes) + if err != nil { + if strings.Contains(err.Error(), "cannot decode encrypted private keys") { + + fmt.Fprintf(os.Stderr, "This SSH key is encrypted. Please enter passphrase for key '%s':", keyfile) + passphrase, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + // println(err.Error()) + return nil, err + } + fmt.Fprintln(os.Stderr) + if signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(passphrase)); err == nil { + return ssh.PublicKeys(signer), nil + } + } return nil, err } - return ssh.PublicKeys(signer), nil } diff --git a/client.go b/client.go index e6a401f..f52829e 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package ssh import ( "errors" + "fmt" "net" "os" "strconv" @@ -19,58 +20,40 @@ type Client struct { SFTPClient *sftp.Client } -// NewClient 根据配置 -func NewClient(user, host, port, password string) (client *Client, err error) { - p, err := strconv.Atoi(port) - if err == nil || p == 0 { - p = 22 - } - if user == "" { - user = "root" - } - var config = &Config{ - User: user, - Host: host, - Port: p, - Password: password, - // KeyFiles: []string{"~/.ssh/id_rsa"}, - } - return New(config) -} - // New 创建SSH client -func New(config *Config) (client *Client, err error) { +func New(cnf *Config) (client *Client, err error) { clientConfig := &ssh.ClientConfig{ - User: config.User, + User: cnf.User, Timeout: DefaultTimeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - // 2. 密码方式 - if config.Password != "" { - clientConfig.Auth = append(clientConfig.Auth, ssh.Password(config.Password)) + if cnf.Port == 0 { + cnf.Port = 22 } - // 3. privite key file - if len(config.KeyFiles) == 0 { + // 1. privite key file + if len(cnf.KeyFiles) == 0 { keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" - if auth, err := WithPrivateKey(keyPath, config.Password); err != nil { + if auth, err := WithPrivateKey(keyPath, cnf.Passphrase); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } } else { - if auth, err := WithPrivateKeys(config.KeyFiles, config.Password); err != nil { + if auth, err := WithPrivateKeys(cnf.KeyFiles, cnf.Passphrase); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } } - // 1. agent - if auth, err := WithAgent(); err != nil { + // 2. 密码方式 放在key之后,这样密钥失败之后可以使用Password方式 + if cnf.Password != "" { + clientConfig.Auth = append(clientConfig.Auth, ssh.Password(cnf.Password)) + } + // 3. agent 模式放在最后,这样当前两者都不能使用时可以采用Agent模式 + if auth, err := WithAgent(); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } - if config.Port == 0 { - config.Port = 22 - } + // hostPort := config.Host + ":" + strconv.Itoa(config.Port) - sshClient, err := ssh.Dial("tcp", net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), clientConfig) + sshClient, err := ssh.Dial("tcp", net.JoinHostPort(cnf.Host, strconv.Itoa(cnf.Port)), clientConfig) if err != nil { return client, errors.New("Failed to dial ssh: " + err.Error()) @@ -85,15 +68,81 @@ func New(config *Config) (client *Client, err error) { return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil } -// Execute cmd on the remote host and return stderr and stdout -func (c *Client) Exec(cmd string) ([]byte, error) { - session, err := c.SSHClient.NewSession() +// NewClient 根据配置 +func NewClient(host, port, user, password string) (client *Client, err error) { + p, _ := strconv.Atoi(port) + // if err != nil { + // p = 22 + // } + if user == "" { + user = "root" + } + var config = &Config{ + Host: host, + Port: p, + User: user, + Password: password, + // KeyFiles: []string{"~/.ssh/id_rsa"}, + Passphrase: password, + } + return New(config) +} + +func NewWithAgent(Host, Port, User string) (client *Client, err error) { + clientConfig := &ssh.ClientConfig{ + User: User, + Timeout: DefaultTimeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + auth, err := WithAgent() if err != nil { return nil, err } - defer session.Close() + clientConfig.Auth = append(clientConfig.Auth, auth) + // hostPort := config.Host + ":" + strconv.Itoa(config.Port) + sshClient, err := ssh.Dial("tcp", net.JoinHostPort(Host, Port), clientConfig) + + if err != nil { + return client, errors.New("Failed to dial ssh: " + err.Error()) + } + + // create sftp client + var sftpClient *sftp.Client + if sftpClient, err = sftp.NewClient(sshClient); err != nil { + return client, errors.New("Failed to conn sftp: " + err.Error()) + } + return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil + +} +func NewWithPrivateKey(Host, Port, User, Passphrase string) (client *Client, err error) { + clientConfig := &ssh.ClientConfig{ + User: User, + Timeout: DefaultTimeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + // 3. privite key file + keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" + auth, err := WithPrivateKey(keyPath, Passphrase) + if err != nil { + fmt.Println(err) + return nil, err + } + clientConfig.Auth = append(clientConfig.Auth, auth) + + // hostPort := config.Host + ":" + strconv.Itoa(config.Port) + sshClient, err := ssh.Dial("tcp", net.JoinHostPort(Host, Port), clientConfig) + + if err != nil { + return client, errors.New("Failed to dial ssh: " + err.Error()) + } + + // create sftp client + var sftpClient *sftp.Client + if sftpClient, err = sftp.NewClient(sshClient); err != nil { + return client, errors.New("Failed to conn sftp: " + err.Error()) + } + return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil - return session.CombinedOutput(cmd) } // Close the underlying SSH connection @@ -101,14 +150,3 @@ func (c *Client) Close() { c.SFTPClient.Close() c.SSHClient.Close() } - -func addPortToHost(host string) string { - _, _, err := net.SplitHostPort(host) - - // We got an error so blindly try to add a port number - if err != nil { - return net.JoinHostPort(host, "22") - } - - return host -} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..ced679f --- /dev/null +++ b/client_test.go @@ -0,0 +1,52 @@ +package ssh + +import ( + "fmt" + "testing" +) + +func TestNewWithAgent(t *testing.T) { + + c, err := NewWithAgent("118.190.117.250", "3009", "root") + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + b, err := c.Run("id") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(b)) +} + +func TestNewClient(t *testing.T) { + c, err := NewClient("192.168.5.154", "22", "root", "123456") + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + b, err := c.Run("id") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(b)) +} + +func TestNewWithPrivateKey(t *testing.T) { + c, err := NewWithPrivateKey("192.168.5.154", "22", "root", "123456") + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + b, err := c.Run("id") + if err != nil { + fmt.Println(err) + return + } + fmt.Println(string(b)) +} diff --git a/config.go b/config.go index bdedbeb..a6cdaf2 100644 --- a/config.go +++ b/config.go @@ -5,12 +5,14 @@ import ( ) type Config struct { - User string - Host string - Port int - Password string - KeyFiles []string + User string + Host string + Port int + Password string + KeyFiles []string + Passphrase string + StickySession bool // DisableAgentForwarding, if true, will not forward the SSH agent. DisableAgentForwarding bool diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..1539d1b --- /dev/null +++ b/ssh.go @@ -0,0 +1,49 @@ +package ssh + +import ( + "fmt" + "path/filepath" +) + +// Run Execute cmd on the remote host and return stderr and stdout +func (c *Client) Run(cmd string) ([]byte, error) { + session, err := c.SSHClient.NewSession() + if err != nil { + return nil, err + } + defer session.Close() + + return session.CombinedOutput(cmd) +} + +//Exec Execute cmd on the remote host and return stderr and stdout +func (c *Client) Exec(cmd string) ([]byte, error) { + session, err := c.SSHClient.NewSession() + if err != nil { + return nil, err + } + defer session.Close() + + return session.CombinedOutput(cmd) +} + +// RunScript Executes a shell script file on the remote machine. +// It is copied in the tmp folder and ran in a single session. +// chmod +x is applied before running. +// Returns an SshResponse and an error if any has occured +func (c *Client) RunScript(scriptPath string) ([]byte, error) { + session, err := c.SSHClient.NewSession() + if err != nil { + return nil, err + } + defer session.Close() + + // 1. 上传 script + remotePath := fmt.Sprintf("/tmp/%s", filepath.Base(scriptPath)) + if err := c.UploadFile(scriptPath, remotePath); err != nil { + return nil, err + } + // 2. 执行script + rCmd := fmt.Sprintf("chmod +x %s ; %s", remotePath, remotePath) + return session.CombinedOutput(rCmd) +}