From a2d60ca1f3e76ede2d2e422d09998022ef480bf6 Mon Sep 17 00:00:00 2001 From: rinetd Date: Tue, 30 Oct 2018 17:45:30 +0800 Subject: [PATCH] golang lib ssh --- README.md | 95 ++++++++ auth.go | 160 +++++++++++++ client.go | 114 +++++++++ cmd/main.go | 24 ++ config.go | 80 +++++++ example/get/main.go | 24 ++ example/put/main.go | 22 ++ sftp.go | 399 ++++++++++++++++++++++++++++++++ sftp_is_test.go | 61 +++++ sftp_test.go | 193 +++++++++++++++ sudo.go | 70 ++++++ test/upload/dir/file | 1 + test/upload/dir/subfold/subfile | 1 + test/upload/file | 1 + 14 files changed, 1245 insertions(+) create mode 100644 README.md create mode 100644 auth.go create mode 100644 client.go create mode 100644 cmd/main.go create mode 100644 config.go create mode 100644 example/get/main.go create mode 100644 example/put/main.go create mode 100644 sftp.go create mode 100644 sftp_is_test.go create mode 100644 sftp_test.go create mode 100644 sudo.go create mode 100644 test/upload/dir/file create mode 100644 test/upload/dir/subfold/subfile create mode 100644 test/upload/file diff --git a/README.md b/README.md new file mode 100644 index 0000000..67c2e7f --- /dev/null +++ b/README.md @@ -0,0 +1,95 @@ + +## 项目简介 +本项目是基于golang标准库 ssh 和 sftp 开发 + +本项目是对标准库进行一个简单的高层封装,使得可以在在 Windows Linux Mac 上非常容易的执行 ssh 命令, +以及文件,文件夹的上传,下载等操作. + +文件上传下载模仿rsync方式: 只和源有关. +// rsync -av src/ dst ./src/* --> /root/dst/* +// rsync -av src/ dst/ ./src/* --> /root/dst/* +// rsync -av src dst ./src/* --> /root/dst/src/* +// rsync -av src dst/ ./src/* --> /root/dst/src/* + +## Example + +### 在远程执行ssh命令 +```go +package main +import ( + "fmt" + "github.com/pytool/ssh" +) +func main() { + + c, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + if err != nil { + panic(err) + } + defer c.Close() + + output, err := c.Exec("uptime") + if err != nil { + panic(err) + } + + fmt.Printf("Uptime: %s\n", output) +} + +``` +### 文件下载 +```go +package main + +import ( + "github.com/pytool/ssh" +) + +func main() { + + client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + if err != nil { + panic(err) + } + defer client.Close() + var remotedir = "/root/test/" + // download dir + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/" + client.Download(remotedir, local) + + // upload file + var remotefile = "/root/test/file" + + client.Download(remotefile, local) +} + +``` + +### 文件上传 +```go +package main + +import ( + "github.com/pytool/ssh" +) + +func main() { + + client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + if err != nil { + panic(err) + } + defer client.Close() + var remotedir = "/root/test/" + // upload dir + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/" + client.Upload(local, remotedir) + + // upload file + local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/file" + client.Upload(local, remotedir) +} + +``` + + diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..6a353f9 --- /dev/null +++ b/auth.go @@ -0,0 +1,160 @@ +package ssh + +import ( + "errors" + "fmt" + "io/ioutil" + "net" + "os" + "strings" + "syscall" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "golang.org/x/crypto/ssh/terminal" +) + +// An implementation of ssh.KeyboardInteractiveChallenge that simply sends +// back the password for all questions. The questions are logged. +func passwordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallenge { + return func(user, instruction string, questions []string, echos []bool) ([]string, error) { + // log.Printf("Keyboard interactive challenge: ") + // log.Printf("-- User: %s", user) + // log.Printf("-- Instructions: %s", instruction) + // for i, question := range questions { + // log.Printf("-- Question %d: %s", i+1, question) + // } + + // Just send the password back for all questions + answers := make([]string, len(questions)) + for i := range answers { + answers[i] = password + } + + return answers, nil + } +} + +// WithKeyboardPassword Generate a password-auth'd ssh ClientConfig +func WithKeyboardPassword(password string) (ssh.AuthMethod, error) { + return ssh.KeyboardInteractive(passwordKeyboardInteractive(password)), nil +} + +// WithPassword Generate a password-auth'd ssh ClientConfig +func WithPassword(password string) (ssh.AuthMethod, error) { + return ssh.Password(password), nil +} + +// WithAgent use already authed user +func WithAgent() (ssh.AuthMethod, error) { + sock := os.Getenv("SSH_AUTH_SOCK") + if sock != "" { + // fmt.Println(errors.New("Agent Disabled")) + return nil, errors.New("Agent Disabled") + } + socks, err := net.Dial("unix", sock) + if err != nil { + fmt.Println(err) + return nil, err + } + // 1. 返回Signers函数的结果 + agent := agent.NewClient(socks) + signers, err := agent.Signers() + return ssh.PublicKeys(signers...), nil + // 2. 返回Signers函数 + // getSigners := agent.NewClient(socks).Signers + // return ssh.PublicKeysCallback(getSigners), nil + + // 3.简写方式 + // if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + // return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + // } + // return nil +} + +// WithPrivateKeys 设置多个 ~/.ssh/id_rsa +func WithPrivateKeys(keyFiles []string, password string) (ssh.AuthMethod, error) { + var signers []ssh.Signer + + for _, key := range keyFiles { + + buffer, err := ioutil.ReadFile(key) + if err != nil { + println(err.Error()) + // return + } + signer, err := ssh.ParsePrivateKeyWithPassphrase([]byte(buffer), []byte(password)) + if err != nil { + println(err.Error()) + } else { + signers = append(signers, signer) + } + } + if signers == nil { + return nil, errors.New("WithPrivateKeys: no keyfiles input") + } + return ssh.PublicKeys(signers...), nil +} + +// WithPrivateKey 自动监测是否带有密码 +func WithPrivateKey(keyfile string, password string) (ssh.AuthMethod, error) { + pemBytes, err := ioutil.ReadFile(keyfile) + if err != nil { + println(err.Error()) + return nil, err + } + + var signer ssh.Signer + signer, err = ssh.ParsePrivateKey(pemBytes) + if err != nil { + if strings.Contains(err.Error(), "cannot decode encrypted private keys") { + if signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password)); err == nil { + return ssh.PublicKeys(signer), nil + } + } + } + return nil, err +} + +// WithPrivateKeyString 直接通过字符串 +func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) { + var signer ssh.Signer + var err error + if password == "" { + signer, err = ssh.ParsePrivateKey([]byte(key)) + } else { + signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(key), []byte(password)) + } + if err != nil { + println(err.Error()) + return nil, err + } + return ssh.PublicKeys(signer), nil +} + +// WithPrivateKeyTerminal 通过终端读取带密码的 PublicKey +func WithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) { + // fmt.Fprintf(os.Stderr, "This SSH key is encrypted. Please enter passphrase for key '%s':", priv.path) + passphrase, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + println(err.Error()) + return nil, err + } + + fmt.Fprintln(os.Stderr) + + pemBytes, err := ioutil.ReadFile(keyfile) + if err != nil { + + println(err.Error()) + return nil, err + } + signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, passphrase) + if err != nil { + + fmt.Println(err) + return nil, err + } + + return ssh.PublicKeys(signer), nil +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..e6a401f --- /dev/null +++ b/client.go @@ -0,0 +1,114 @@ +package ssh + +import ( + "errors" + "net" + "os" + "strconv" + "time" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +const DefaultTimeout = 30 * time.Second + +type Client struct { + *Config + SSHClient *ssh.Client + SFTPClient *sftp.Client +} + +// NewClient 根据配置 +func NewClient(user, host, port, password string) (client *Client, err error) { + p, err := strconv.Atoi(port) + if err == nil || p == 0 { + p = 22 + } + if user == "" { + user = "root" + } + var config = &Config{ + User: user, + Host: host, + Port: p, + Password: password, + // KeyFiles: []string{"~/.ssh/id_rsa"}, + } + return New(config) +} + +// New 创建SSH client +func New(config *Config) (client *Client, err error) { + clientConfig := &ssh.ClientConfig{ + User: config.User, + Timeout: DefaultTimeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + // 2. 密码方式 + if config.Password != "" { + clientConfig.Auth = append(clientConfig.Auth, ssh.Password(config.Password)) + } + + // 3. privite key file + if len(config.KeyFiles) == 0 { + keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" + if auth, err := WithPrivateKey(keyPath, config.Password); err != nil { + clientConfig.Auth = append(clientConfig.Auth, auth) + } + } else { + if auth, err := WithPrivateKeys(config.KeyFiles, config.Password); err != nil { + clientConfig.Auth = append(clientConfig.Auth, auth) + } + } + // 1. agent + if auth, err := WithAgent(); err != nil { + clientConfig.Auth = append(clientConfig.Auth, auth) + } + if config.Port == 0 { + config.Port = 22 + } + // hostPort := config.Host + ":" + strconv.Itoa(config.Port) + sshClient, err := ssh.Dial("tcp", net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), clientConfig) + + if err != nil { + return client, errors.New("Failed to dial ssh: " + err.Error()) + } + + // create sftp client + var sftpClient *sftp.Client + if sftpClient, err = sftp.NewClient(sshClient); err != nil { + return client, errors.New("Failed to conn sftp: " + err.Error()) + } + + return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil +} + +// Execute cmd on the remote host and return stderr and stdout +func (c *Client) Exec(cmd string) ([]byte, error) { + session, err := c.SSHClient.NewSession() + if err != nil { + return nil, err + } + defer session.Close() + + return session.CombinedOutput(cmd) +} + +// Close the underlying SSH connection +func (c *Client) Close() { + c.SFTPClient.Close() + c.SSHClient.Close() +} + +func addPortToHost(host string) string { + _, _, err := net.SplitHostPort(host) + + // We got an error so blindly try to add a port number + if err != nil { + return net.JoinHostPort(host, "22") + } + + return host +} diff --git a/cmd/main.go b/cmd/main.go new file mode 100644 index 0000000..81c80bb --- /dev/null +++ b/cmd/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + + "github.com/pytool/ssh" +) + +func main() { + + client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + if err != nil { + panic(err) + } + defer client.Close() + + output, err := client.Exec("uptime") + if err != nil { + panic(err) + } + + fmt.Printf("Uptime: %s\n", output) + +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..bdedbeb --- /dev/null +++ b/config.go @@ -0,0 +1,80 @@ +package ssh + +import ( + "time" +) + +type Config struct { + User string + Host string + Port int + Password string + KeyFiles []string + + // DisableAgentForwarding, if true, will not forward the SSH agent. + DisableAgentForwarding bool + + // HandshakeTimeout limits the amount of time we'll wait to handshake before + // saying the connection failed. + HandshakeTimeout time.Duration + + // KeepAliveInterval sets how often we send a channel request to the + // server. A value < 0 disables. + KeepAliveInterval time.Duration + + // Timeout is how long to wait for a read or write to succeed. + Timeout time.Duration +} + +var DefaultConfig = &Config{ + User: "root", + Port: 22, + KeyFiles: []string{"~/.ssh/id_rsa"}, +} + +// +func (c *Client) WithUser(user string) *Client { + if user == "" { + user = "root" + } + c.User = user + return c +} + +// +func (c *Client) WithHost(host string) *Client { + if host == "" { + host = "localhost" + } + c.Host = host + return c +} +func (c *Client) WithPassword(password string) *Client { + c.Password = password + return c +} + +// +func (c *Client) SetKeys(keyfiles []string) *Client { + if keyfiles == nil { + return c + } + t := make([]string, len(keyfiles)) + copy(t, keyfiles) + c.KeyFiles = t + return c +} + +// +func (c *Client) WithKey(keyfile string) *Client { + if keyfile == "" { + keyfile = "~/.ssh/id_rsa" + } + for _, s := range c.KeyFiles { + if s == keyfile { + return c + } + } + c.KeyFiles = append(c.KeyFiles, keyfile) + return c +} diff --git a/example/get/main.go b/example/get/main.go new file mode 100644 index 0000000..7358c96 --- /dev/null +++ b/example/get/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "github.com/pytool/ssh" +) + +func main() { + + client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + if err != nil { + panic(err) + } + defer client.Close() + var remotedir = "/root/test/" + // download dir + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/" + client.Download(remotedir, local) + + // upload file + var remotefile = "/root/test/file" + + client.Download(remotefile, local) + +} diff --git a/example/put/main.go b/example/put/main.go new file mode 100644 index 0000000..ef29e1b --- /dev/null +++ b/example/put/main.go @@ -0,0 +1,22 @@ +package main + +import ( + "github.com/pytool/ssh" +) + +func main() { + + client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + if err != nil { + panic(err) + } + defer client.Close() + var remotedir = "/root/test/" + // upload dir + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/" + client.Upload(local, remotedir) + // upload file + local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/file" + client.Upload(local, remotedir) + +} diff --git a/sftp.go b/sftp.go new file mode 100644 index 0000000..c5e8538 --- /dev/null +++ b/sftp.go @@ -0,0 +1,399 @@ +package ssh + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "path" + "path/filepath" + "strings" +) + +// Upload 上传本地文件 local 到sftp远程目录 remote like rsync +// rsync -av src/ dst ./src/* --> /root/dst/* +// rsync -av src/ dst/ ./src/* --> /root/dst/* +// rsync -av src dst ./src/* --> /root/dst/src/* +// rsync -av src dst/ ./src/* --> /root/dst/src/* +func (c *Client) Upload(local string, remote string) (err error) { + // var localDir, localFile, remoteDir, remoteFile string + + info, err := os.Stat(local) + if err != nil { + return errors.New("本地文件不存在或格式错误 Upload(\"" + local + "\") 跳过上传") + } + if info.IsDir() { + return c.UploadDir(local, remote) + } + return c.UploadFile(local, remote) +} + +// Download 下载sftp远程文件 remote 到本地 local like rsync +func (c *Client) Download(remote string, local string) (err error) { + if c.IsNotExist(strings.TrimSuffix(remote, "/")) { + return errors.New("文件不存在,跳过文件下载 \"" + remote + "\" ") + } + if c.IsDir(remote) { + // return errors.New("检测到远程是文件不是目录 \"" + remote + "\" 跳过下载") + return c.downloadDir(remote, local) + + } + return c.downloadFile(remote, local) + +} + +// downloadFile a file from the remote server like cp +func (c *Client) downloadFile(remoteFile, local string) error { + // remoteFile = strings.TrimSuffix(remoteFile, "/") + if !c.IsFile(remoteFile) { + return errors.New("文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")") + } + var localFile string + if local[len(local)-1] == '/' { + localFile = filepath.Join(local, filepath.Base(remoteFile)) + } else { + localFile = local + } + + if err := os.MkdirAll(filepath.Dir(localFile), os.ModePerm); err != nil { + // fmt.Println(err) + return err + } + + r, err := c.SFTPClient.Open(remoteFile) + if err != nil { + return err + } + defer r.Close() + + l, err := os.Create(localFile) + if err != nil { + return err + } + defer l.Close() + + _, err = io.Copy(l, r) + return err +} + +// downloadDir from remote dir to local dir like rsync +// rsync -av src/ dst ./src/* --> /root/dst/* +// rsync -av src/ dst/ ./src/* --> /root/dst/* +// rsync -av src dst ./src/* --> /root/dst/src/* +// rsync -av src dst/ ./src/* --> /root/dst/src/* +func (c *Client) downloadDir(remote, local string) error { + var localDir, remoteDir string + + if !c.IsDir(remote) { + return errors.New("目录不存在或不是目录, 跳过 downloadDir(" + remote + ")") + } + remoteDir = remote + if remote[len(remote)-1] == '/' { + localDir = local + } else { + localDir = path.Join(local, path.Base(remote)) + } + + walker := c.SFTPClient.Walk(remoteDir) + + for walker.Step() { + if err := walker.Err(); err != nil { + fmt.Fprintln(os.Stderr, err) + continue + } + + info := walker.Stat() + + relPath, err := filepath.Rel(remoteDir, walker.Path()) + if err != nil { + return err + } + + localPath := filepath.ToSlash(filepath.Join(localDir, relPath)) + + // if we have something at the download path delete it if it is a directory + // and the remote is a file and vice a versa + localInfo, err := os.Stat(localPath) + if os.IsExist(err) { + if localInfo.IsDir() { + if info.IsDir() { + continue + } + + err = os.RemoveAll(localPath) + if err != nil { + return err + } + } else if info.IsDir() { + err = os.Remove(localPath) + if err != nil { + return err + } + } + } + + if info.IsDir() { + err = os.MkdirAll(localPath, os.ModePerm) + if err != nil { + return err + } + + continue + } + + c.downloadFile(walker.Path(), localPath) + + } + return nil +} + +//UploadFile 上传本地文件 localFile 到sftp远程目录 remote +func (c *Client) UploadFile(localFile, remote string) error { + // localFile = strings.TrimSuffix(localFile, "/") + info, err := os.Stat(localFile) + if err != nil || info.IsDir() { + return errors.New("本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传") + } + + l, err := os.Open(localFile) + if err != nil { + return err + } + defer l.Close() + + var remoteFile, remoteDir string + if remote[len(remote)-1] == '/' { + remoteFile = filepath.Join(remote, filepath.Base(localFile)) + remoteDir = remote + } else { + remoteFile = remote + remoteDir = filepath.Dir(remoteFile) + } + + // 目录不存在,则创建 remoteDir + if _, err := c.SFTPClient.Stat(remoteDir); err != nil { + c.MkdirAll(remoteDir) + } + + r, err := c.SFTPClient.Create(remoteFile) + if err != nil { + return err + } + + _, err = io.Copy(r, l) + return err +} + +// UploadDir files without checking diff status +func (c *Client) UploadDir(localDir string, remoteDir string) (err error) { + // defer func() { + // if err != nil { + // err = errors.New("UploadDir " + err.Error()) + // } + // }() + // 本地输入检测,必须是目录 + info, err := os.Stat(localDir) + if err != nil || !info.IsDir() { + return errors.New("本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传") + } + + // 模仿 rsync localDir不以'/'结尾,则创建尾目录 + if localDir[len(localDir)-1] != '/' { + remoteDir = filepath.Join(remoteDir, filepath.Base(localDir)) + } + // fmt.Println("remoteDir", remoteDir) + + rootDst := strings.TrimSuffix(remoteDir, "/") + if c.IsFile(rootDst) { + c.SFTPClient.Remove(rootDst) + } + + walkFunc := func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + // Calculate the final destination using the + // base source and root destination + relSrc, err := filepath.Rel(localDir, path) + if err != nil { + return err + } + finalDst := filepath.Join(rootDst, relSrc) + + // In Windows, Join uses backslashes which we don't want to get + // to the sftp server + finalDst = filepath.ToSlash(finalDst) + + // Skip the creation of the target destination directory since + // it should exist and we might not even own it + if finalDst == remoteDir { + return nil + fmt.Println("skip", remoteDir, "--->", finalDst) + + } + + if info.IsDir() { + c.MkdirAll(finalDst) + // err = c.SFTPClient.Mkdir(finalDst) + // fmt.Println(err) + // if err := c.SFTPClient.Mkdir(finalDst); err != nil { + // // Do not consider it an error if the directory existed + // remoteFi, fiErr := c.SFTPClient.Lstat(finalDst) + // if fiErr != nil || !remoteFi.IsDir() { + // return err + // } + // } + // return err + } else { + // f, err := os.Open(path) + // if err != nil { + // return err + // } + // defer f.Close() + return c.UploadFile(path, finalDst) + } + return nil + + } + return filepath.Walk(localDir, walkFunc) +} + +// Remove a file from the remote server +func (c *Client) Remove(path string) error { + return c.SFTPClient.Remove(path) +} + +// RemoveDirectory Remove a directory from the remote server +func (c *Client) RemoveDirectory(path string) error { + return c.SFTPClient.RemoveDirectory(path) +} + +// ReadAll Read a remote file and return the contents. +func (c *Client) ReadAll(filepath string) ([]byte, error) { + file, err := c.SFTPClient.Open(filepath) + if err != nil { + return nil, err + } + defer file.Close() + + return ioutil.ReadAll(file) +} + +//FileExist 文件是否存在 +func (c *Client) FileExist(filepath string) (bool, error) { + if _, err := c.SFTPClient.Stat(filepath); err != nil { + return false, err + } + return true, nil +} + +func (c *Client) RemoveFile(remoteFile string) error { + return c.SFTPClient.Remove(remoteFile) +} +func (c *Client) RemoveDir(remoteDir string) error { + remoteFiles, err := c.SFTPClient.ReadDir(remoteDir) + if err != nil { + log.Printf("remove remote dir: %s err: %v\n", remoteDir, err) + return err + } + for _, file := range remoteFiles { + subRemovePath := path.Join(remoteDir, file.Name()) + if file.IsDir() { + c.RemoveDir(subRemovePath) + } else { + c.RemoveFile(subRemovePath) + } + } + c.SFTPClient.RemoveDirectory(remoteDir) //must empty dir to remove + log.Printf("remove remote dir: %s ok\n", remoteDir) + return nil +} + +//RemoveAll 递归删除目录,文件 +func (c *Client) RemoveAll(remoteDir string) error { + c.RemoveDir(remoteDir) + return nil +} + +//MkdirAll 创建目录,递归 +func (c *Client) MkdirAll(dirpath string) error { + parentDir := filepath.Dir(dirpath) + _, err := c.SFTPClient.Stat(parentDir) + if err != nil { + if err.Error() == "file does not exist" { + err := c.MkdirAll(parentDir) + if err != nil { + return err + } + } else { + return err + } + } + err = c.SFTPClient.Mkdir(dirpath) + if err != nil { + return err + } + return nil +} + +func MkdirAll(path string) error { + // 检测文件夹是否存在 若不存在 创建文件夹 + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return os.MkdirAll(path, os.ModePerm) + } + } + return nil +} + +func (c *Client) Mkdir(path string, fi os.FileInfo) error { + log.Printf("[DEBUG] sftp: creating dir %s", path) + + if err := c.SFTPClient.Mkdir(path); err != nil { + // Do not consider it an error if the directory existed + remoteFi, fiErr := c.SFTPClient.Lstat(path) + if fiErr != nil || !remoteFi.IsDir() { + return err + } + } + + mode := fi.Mode().Perm() + if err := c.SFTPClient.Chmod(path, mode); err != nil { + return err + } + return nil +} + +//IsDir 检查远程是否是个目录 +func (c *Client) IsDir(path string) bool { + // 检查远程是文件还是目录 + info, err := c.SFTPClient.Stat(path) + if err == nil && info.IsDir() { + return true + } + return false +} + +//IsFile 检查远程是否是个文件 +func (c *Client) IsFile(path string) bool { + info, err := c.SFTPClient.Stat(path) + if err == nil && !info.IsDir() { + return true + } + return false +} + +//IsNotExist 检查远程是文件是否不存在 +func (c *Client) IsNotExist(path string) bool { + _, err := c.SFTPClient.Stat(path) + return err != nil +} + +//IsExist 检查远程是文件是否存在 +func (c *Client) IsExist(path string) bool { + + _, err := c.SFTPClient.Stat(path) + return err == nil +} diff --git a/sftp_is_test.go b/sftp_is_test.go new file mode 100644 index 0000000..35d68d7 --- /dev/null +++ b/sftp_is_test.go @@ -0,0 +1,61 @@ +package ssh + +import ( + "testing" +) + +func TestClient_IsCheck(t *testing.T) { + c := GetClient() + defer c.Close() + var remotes = []string{ + "/root/test/notExist", + "/root/test/notExist/", + "/root/test/file", + "/root/test/file/", // 不存在 + "/root/test/dir", + "/root/test/dir/", + } + + // /root/test/file 存在 + // /root/test/file/ 不存在 + // /root/test/dir 存在 + // /root/test/dir/ 存在 + for _, v := range remotes { + is := c.IsExist(v) + if is { + println(v, "\t存在") + } else { + println(v, "\t不存在") + } + } + + // /root/test/file 不是一个目录 + // /root/test/file/ 不是一个目录 + // /root/test/dir 是一个目录 + // /root/test/dir/ 是一个目录 + println() + for _, v := range remotes { + isdir := c.IsDir(v) + if isdir { + println(v, "\t是一个目录") + } else { + println(v, "\t不是一个目录") + } + } + + // /root/test/file 是一个文件 + // /root/test/file/ 不是一个文件 + // /root/test/dir 不是一个文件 + // /root/test/dir/ 不是一个文件 + println() + for _, v := range remotes { + isfile := c.IsFile(v) + if isfile { + println(v, "\t是一个文件") + } else { + println(v, "\t不是一个文件") + } + + } + +} diff --git a/sftp_test.go b/sftp_test.go new file mode 100644 index 0000000..d5c4c93 --- /dev/null +++ b/sftp_test.go @@ -0,0 +1,193 @@ +package ssh + +import ( + "sync" + "testing" +) + +func GetClient() *Client { + var ( + once = sync.Once{} + c = &Client{} + err error + ) + once.Do(func() { + c, err = NewClient("root", "localhost", "22", "ubuntu") + }) + if err != nil { + panic(err) + } + return c +} + +// func TestClient_RemoveAll(t *testing.T) { +// c := GetClient() +// defer c.Close() +// var remotedir = "/root/test/" +// fmt.Println(c.RemoveAll(remotedir)) +// } +func TestClient_Init(t *testing.T) { + c := GetClient() + defer c.Close() + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/" + + var remotedir = "/root/test/" + c.RemoveAll("/root/upload/") + + err := c.Upload(local, remotedir) + if err != nil { + println("[Upload]", err.Error()) + } +} +func TestClient_Upload(t *testing.T) { + c := GetClient() + defer c.Close() + + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/" + var uploads = map[string][]string{ + local + "null/": []string{"/root/upload/test/null/1", "/root/upload/test/null/2/"}, + local + "null/": []string{"/root/upload/test/null/3", "/root/upload/test/null/4/"}, + local + "file": []string{"/root/upload/test/file/1", "/root/upload/test/file/2/"}, + local + "file/": []string{"/root/upload/test/file/3", "/root/upload/test/file/4/"}, + local + "dir": []string{"/root/upload/test/dir/1", "/root/upload/test/dir/2/"}, + local + "dir/": []string{"/root/upload/test/dir/3", "/root/upload/test/dir/4/"}, + } + + for local, remotes := range uploads { + for _, remote := range remotes { + err := c.Upload(local, remote) + if err != nil { + println(err.Error()) + } + // println(remote, "--->", v, "Finish download!") + + } + } +} + +func TestClient_Download(t *testing.T) { + c := GetClient() + defer c.Close() + + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download" + var download = map[string][]string{ + "/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"}, + "/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"}, + "/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"}, + "/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"}, + "/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"}, + "/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"}, + } + + for remote, local := range download { + for _, v := range local { + err := c.Download(remote, v) + if err != nil { + println(err.Error()) + } + // println(remote, "--->", v, "Finish download!") + + } + } + +} + +func TestClient_DownloadFile(t *testing.T) { + c := GetClient() + defer c.Close() + + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/downloadfile" + var download = map[string][]string{ + "/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"}, + "/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"}, + "/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"}, + "/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"}, + "/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"}, + "/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"}, + } + + for remote, local := range download { + for _, v := range local { + err := c.downloadFile(remote, v) + if err != nil { + println(err.Error()) + } + // println(remote, "--->", v, "Finish download!") + + } + } +} +func TestClient_DownloadDir(t *testing.T) { + c := GetClient() + defer c.Close() + + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/downloaddir" + var download = map[string][]string{ + "/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"}, + "/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"}, + "/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"}, + "/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"}, + "/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"}, + "/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"}, + } + + for remote, local := range download { + for _, v := range local { + err := c.downloadDir(remote, v) + if err != nil { + println(err.Error()) + } + // println(remote, "--->", v, "Finish download!") + + } + } +} +func TestClient_UploadFile(t *testing.T) { + c := GetClient() + defer c.Close() + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/" + var uploads = map[string][]string{ + local + "null": []string{"/root/upload/file_test/null/1", "/root/upload/file_test/null/2/"}, + local + "null/": []string{"/root/upload/file_test/null/3", "/root/upload/file_test/null/4/"}, + local + "file": []string{"/root/upload/file_test/file/1", "/root/upload/file_test/file/2/"}, + local + "file/": []string{"/root/upload/file_test/file/3", "/root/upload/file_test/file/4/"}, + local + "dir": []string{"/root/upload/file_test/dir/1", "/root/upload/file_test/dir/2/"}, + local + "dir/": []string{"/root/upload/file_test/dir/3", "/root/upload/file_test/dir/4/"}, + } + + for local, remotes := range uploads { + for _, remote := range remotes { + err := c.UploadFile(local, remote) + if err != nil { + println(err.Error()) + } + // println(remote, "--->", v, "Finish download!") + + } + } +} + +func TestClient_UploadDir(t *testing.T) { + c := GetClient() + defer c.Close() + var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/" + var uploads = map[string][]string{ + local + "null/": []string{"/root/upload/dir_test/null/1", "/root/upload/dir_test/null/2/"}, + local + "null/": []string{"/root/upload/dir_test/null/3", "/root/upload/dir_test/null/4/"}, + local + "file": []string{"/root/upload/dir_test/file/1", "/root/upload/dir_test/file/2/"}, + local + "file/": []string{"/root/upload/dir_test/file/3", "/root/upload/dir_test/file/4/"}, + local + "dir": []string{"/root/upload/dir_test/dir/1", "/root/upload/dir_test/dir/2/"}, + local + "dir/": []string{"/root/upload/dir_test/dir/3", "/root/upload/dir_test/dir/4/"}, + } + + for local, remotes := range uploads { + for _, remote := range remotes { + err := c.UploadDir(local, remote) + if err != nil { + println(err.Error()) + } + // println(remote, "--->", v, "Finish download!") + + } + } +} diff --git a/sudo.go b/sudo.go new file mode 100644 index 0000000..344cfa8 --- /dev/null +++ b/sudo.go @@ -0,0 +1,70 @@ +package ssh + +import ( + "bytes" + "io" + "sync" +) + +// This is the phrase that tells us sudo is looking for a password via stdin +const sudoPwPrompt = "sudo_password" + +// sudoWriter is used to both combine stdout and stderr as well as +// look for a password request from sudo. +type sudoWriter struct { + b bytes.Buffer + pw string // The password to pass to sudo (if requested) + stdin io.Writer // The writer from the ssh session + m sync.Mutex +} + +func (w *sudoWriter) Write(p []byte) (int, error) { + // If we get the sudo password prompt phrase send the password via stdin + // and don't write it to the buffer. + if string(p) == sudoPwPrompt { + w.stdin.Write([]byte(w.pw + "\n")) + w.pw = "" // We don't need the password anymore so reset the string + return len(p), nil + } + + w.m.Lock() + defer w.m.Unlock() + + return w.b.Write(p) +} + +// ExecSu Execute cmd via sudo. Do not include the sudo command in +// the cmd string. For example: Client.ExecSudo("uptime", "password"). +// If you are using passwordless sudo you can use the regular Exec() +// function. +func (c *Client) ExecSu(cmd, passwd string) ([]byte, error) { + session, err := c.SSHClient.NewSession() + if err != nil { + return nil, err + } + defer session.Close() + + // -n run non interactively + // -p specify the prompt. We do this to know that sudo is asking for a passwd + // -S Writes the prompt to StdErr and reads the password from StdIn + cmd = "sudo -p " + sudoPwPrompt + " -S " + cmd + + // Use the sudoRW struct to handle the interaction with sudo and capture the + // output of the command + w := &sudoWriter{ + pw: passwd, + } + w.stdin, err = session.StdinPipe() + if err != nil { + return nil, err + } + + // Combine stdout, stderr to the same writer which also looks for the sudo + // password prompt + session.Stdout = w + session.Stderr = w + + err = session.Run(cmd) + + return w.b.Bytes(), err +} diff --git a/test/upload/dir/file b/test/upload/dir/file new file mode 100644 index 0000000..0e0d662 --- /dev/null +++ b/test/upload/dir/file @@ -0,0 +1 @@ +dir file \ No newline at end of file diff --git a/test/upload/dir/subfold/subfile b/test/upload/dir/subfold/subfile new file mode 100644 index 0000000..ef9c19f --- /dev/null +++ b/test/upload/dir/subfold/subfile @@ -0,0 +1 @@ +this is subdir file \ No newline at end of file diff --git a/test/upload/file b/test/upload/file new file mode 100644 index 0000000..1a010b1 --- /dev/null +++ b/test/upload/file @@ -0,0 +1 @@ +file \ No newline at end of file