diff --git a/README.md b/README.md index 92b16e7..4e5845a 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,15 @@ 本项目是对标准库进行一个简单的高层封装,使得可以在在 Windows Linux Mac 上非常容易的执行 ssh 命令, 以及文件,文件夹的上传,下载等操作. - -文件上传下载模仿rsync方式: 只和源有关. +1. 当src 为目录时 +文件上传下载模仿rsync: 只和源有关. // rsync -av src/ dst ./src/* --> /root/dst/* // rsync -av src/ dst/ ./src/* --> /root/dst/* // rsync -av src dst ./src/* --> /root/dst/src/* // rsync -av src dst/ ./src/* --> /root/dst/src/* - +2. 当src 为文件时 +当dst为目录,以"/"结尾,则自动拼接上文件名 +当dst为文件,不以“/”结尾时,则重命名文件 ## Install `go get github.com/pytool/ssh` ## Example @@ -24,7 +26,7 @@ import ( ) func main() { - c, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + c, err := ssh.NewClient("localhost", "22", "root", "ubuntu") if err != nil { panic(err) } @@ -49,7 +51,7 @@ import ( func main() { - client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + client, err := ssh.NewClient( "localhost", "22", "root", "ubuntu") if err != nil { panic(err) } @@ -77,7 +79,7 @@ import ( func main() { - client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + client, err := ssh.NewClient( "localhost", "22", "root", "ubuntu") if err != nil { panic(err) } diff --git a/auth.go b/auth.go index c78940d..1debd2a 100644 --- a/auth.go +++ b/auth.go @@ -52,18 +52,18 @@ func passwordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallen } } -// WithKeyboardPassword Generate a password-auth'd ssh ClientConfig -func WithKeyboardPassword(password string) (ssh.AuthMethod, error) { +// AuthWithKeyboardPassword Generate a password-auth'd ssh ClientConfig +func AuthWithKeyboardPassword(password string) (ssh.AuthMethod, error) { return ssh.KeyboardInteractive(passwordKeyboardInteractive(password)), nil } -// WithPassword Generate a password-auth'd ssh ClientConfig -func WithPassword(password string) (ssh.AuthMethod, error) { +// AuthWithPassword Generate a password-auth'd ssh ClientConfig +func AuthWithPassword(password string) (ssh.AuthMethod, error) { return ssh.Password(password), nil } -// WithAgent use already authed user -func WithAgent() (ssh.AuthMethod, error) { +// AuthWithAgent use already authed user +func AuthWithAgent() (ssh.AuthMethod, error) { sock := os.Getenv("SSH_AUTH_SOCK") if sock == "" { // fmt.Println(errors.New("Agent Disabled")) @@ -89,8 +89,8 @@ func WithAgent() (ssh.AuthMethod, error) { // return nil } -// WithPrivateKeys 设置多个 ~/.ssh/id_rsa ,如果加密用passphrase尝试 -func WithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, error) { +// AuthWithPrivateKeys 设置多个 ~/.ssh/id_rsa ,如果加密用passphrase尝试 +func AuthWithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, error) { var signers []ssh.Signer for _, key := range keyFiles { @@ -118,8 +118,8 @@ func WithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, erro return ssh.PublicKeys(signers...), nil } -// WithPrivateKey 自动监测是否带有密码 -func WithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) { +// AuthWithPrivateKey 自动监测是否带有密码 +func AuthWithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) { pemBytes, err := ioutil.ReadFile(keyfile) if err != nil { println(err.Error()) @@ -141,8 +141,8 @@ func WithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) { } -// WithPrivateKeyString 直接通过字符串 -func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) { +// AuthWithPrivateKeyString 直接通过字符串 +func AuthWithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) { var signer ssh.Signer var err error if password == "" { @@ -157,8 +157,8 @@ func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) { return ssh.PublicKeys(signer), nil } -// WithPrivateKeyTerminal 通过终端读取带密码的 PublicKey -func WithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) { +// AuthWithPrivateKeyTerminal 通过终端读取带密码的 PublicKey +func AuthWithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) { pemBytes, err := ioutil.ReadFile(keyfile) if err != nil { diff --git a/client.go b/client.go index f52829e..bb26658 100644 --- a/client.go +++ b/client.go @@ -35,11 +35,11 @@ func New(cnf *Config) (client *Client, err error) { // 1. privite key file if len(cnf.KeyFiles) == 0 { keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" - if auth, err := WithPrivateKey(keyPath, cnf.Passphrase); err == nil { + if auth, err := AuthWithPrivateKey(keyPath, cnf.Passphrase); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } } else { - if auth, err := WithPrivateKeys(cnf.KeyFiles, cnf.Passphrase); err == nil { + if auth, err := AuthWithPrivateKeys(cnf.KeyFiles, cnf.Passphrase); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } } @@ -48,7 +48,7 @@ func New(cnf *Config) (client *Client, err error) { clientConfig.Auth = append(clientConfig.Auth, ssh.Password(cnf.Password)) } // 3. agent 模式放在最后,这样当前两者都不能使用时可以采用Agent模式 - if auth, err := WithAgent(); err == nil { + if auth, err := AuthWithAgent(); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } @@ -94,7 +94,7 @@ func NewWithAgent(Host, Port, User string) (client *Client, err error) { Timeout: DefaultTimeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - auth, err := WithAgent() + auth, err := AuthWithAgent() if err != nil { return nil, err } @@ -108,7 +108,7 @@ func NewWithAgent(Host, Port, User string) (client *Client, err error) { // create sftp client var sftpClient *sftp.Client - if sftpClient, err = sftp.NewClient(sshClient); err != nil { + if sftpClient, err = sftp.NewClient(sshClient, sftp.MaxPacket(10240000)); err != nil { return client, errors.New("Failed to conn sftp: " + err.Error()) } return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil @@ -122,7 +122,7 @@ func NewWithPrivateKey(Host, Port, User, Passphrase string) (client *Client, err } // 3. privite key file keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" - auth, err := WithPrivateKey(keyPath, Passphrase) + auth, err := AuthWithPrivateKey(keyPath, Passphrase) if err != nil { fmt.Println(err) return nil, err diff --git a/cmd/main.go b/cmd/main.go index 81c80bb..560a304 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -7,8 +7,9 @@ import ( ) func main() { - - client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + config := ssh.Default.WithPassword("ubuntu") + client, err := ssh.New(config) + // client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") if err != nil { panic(err) } diff --git a/config.go b/config.go index a6cdaf2..37a3ae0 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,8 @@ package ssh import ( + "os" + "path" "time" ) @@ -29,48 +31,51 @@ type Config struct { } var DefaultConfig = &Config{ - User: "root", - Port: 22, - KeyFiles: []string{"~/.ssh/id_rsa"}, + Host: "localhost", + Port: 22, + User: "root", + // KeyFiles: []string{path.Join(os.Getenv("HOME"), "/.ssh/id_rsa")}, } +var Default = DefaultConfig -// -func (c *Client) WithUser(user string) *Client { +func WithUser(user string) *Config { + return Default.WithUser(user) +} +func (c *Config) WithUser(user string) *Config { if user == "" { user = "root" } c.User = user return c } +func WithHost(host string) *Config { + return Default.WithHost(host) +} -// -func (c *Client) WithHost(host string) *Client { +func (c *Config) WithHost(host string) *Config { if host == "" { host = "localhost" } c.Host = host return c } -func (c *Client) WithPassword(password string) *Client { + +func WithPassword(password string) *Config { + return Default.WithPassword(password) +} +func (c *Config) WithPassword(password string) *Config { c.Password = password return c } -// -func (c *Client) SetKeys(keyfiles []string) *Client { - if keyfiles == nil { - return c - } - t := make([]string, len(keyfiles)) - copy(t, keyfiles) - c.KeyFiles = t - return c +func WithKey(keyfile, passphrase string) *Config { + return Default.WithKey(keyfile, passphrase) } - -// -func (c *Client) WithKey(keyfile string) *Client { +func (c *Config) WithKey(keyfile, passphrase string) *Config { if keyfile == "" { - keyfile = "~/.ssh/id_rsa" + if home := os.Getenv("HOME"); home != "" { + keyfile = path.Join(home, "/.ssh/id_rsa") + } } for _, s := range c.KeyFiles { if s == keyfile { @@ -80,3 +85,14 @@ func (c *Client) WithKey(keyfile string) *Client { c.KeyFiles = append(c.KeyFiles, keyfile) return c } + +// +func (c *Config) SetKeys(keyfiles []string) *Config { + if keyfiles == nil { + return c + } + t := make([]string, len(keyfiles)) + copy(t, keyfiles) + c.KeyFiles = t + return c +} diff --git a/example/get/main.go b/example/get/main.go index 7358c96..cb3c772 100644 --- a/example/get/main.go +++ b/example/get/main.go @@ -6,7 +6,7 @@ import ( func main() { - client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") if err != nil { panic(err) } diff --git a/example/put/main.go b/example/put/main.go index ef29e1b..ed22443 100644 --- a/example/put/main.go +++ b/example/put/main.go @@ -6,7 +6,7 @@ import ( func main() { - client, err := ssh.NewClient("root", "localhost", "22", "ubuntu") + client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") if err != nil { panic(err) } diff --git a/sftp.go b/sftp.go index c5e8538..80bd23a 100644 --- a/sftp.go +++ b/sftp.go @@ -1,8 +1,8 @@ package ssh import ( + "bytes" "errors" - "fmt" "io" "io/ioutil" "log" @@ -22,9 +22,10 @@ func (c *Client) Upload(local string, remote string) (err error) { info, err := os.Stat(local) if err != nil { - return errors.New("本地文件不存在或格式错误 Upload(\"" + local + "\") 跳过上传") + return errors.New("sftp: 跳过上传 Upload(\"" + local + "\") ,本地文件不存在或格式错误!") } if info.IsDir() { + log.Println("sftp: UploadDir", local) return c.UploadDir(local, remote) } return c.UploadFile(local, remote) @@ -33,7 +34,7 @@ func (c *Client) Upload(local string, remote string) (err error) { // Download 下载sftp远程文件 remote 到本地 local like rsync func (c *Client) Download(remote string, local string) (err error) { if c.IsNotExist(strings.TrimSuffix(remote, "/")) { - return errors.New("文件不存在,跳过文件下载 \"" + remote + "\" ") + return errors.New("sftp: 远程文件不存在,跳过文件下载 \"" + remote + "\" ") } if c.IsDir(remote) { // return errors.New("检测到远程是文件不是目录 \"" + remote + "\" 跳过下载") @@ -48,7 +49,7 @@ func (c *Client) Download(remote string, local string) (err error) { func (c *Client) downloadFile(remoteFile, local string) error { // remoteFile = strings.TrimSuffix(remoteFile, "/") if !c.IsFile(remoteFile) { - return errors.New("文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")") + return errors.New("sftp: 文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")") } var localFile string if local[len(local)-1] == '/' { @@ -56,9 +57,22 @@ func (c *Client) downloadFile(remoteFile, local string) error { } else { localFile = local } - + if c.Size(remoteFile) > 1000 { + rsum := c.Md5File(remoteFile) + ioutil.WriteFile(localFile+".md5", []byte(rsum), 755) + if FileExist(localFile) { + // 1. 检测远程是否存在 + if rsum != "" { + lsum, _ := Md5File(localFile) + if lsum == rsum { + log.Println("sftp: 文件与本地一致,跳过上传!", localFile) + return nil + } + } + } + } if err := os.MkdirAll(filepath.Dir(localFile), os.ModePerm); err != nil { - // fmt.Println(err) + // log.Println(err) return err } @@ -87,7 +101,7 @@ func (c *Client) downloadDir(remote, local string) error { var localDir, remoteDir string if !c.IsDir(remote) { - return errors.New("目录不存在或不是目录, 跳过 downloadDir(" + remote + ")") + return errors.New("sftp: 目录不存在或不是目录, 跳过 downloadDir(" + remote + ")") } remoteDir = remote if remote[len(remote)-1] == '/' { @@ -100,7 +114,7 @@ func (c *Client) downloadDir(remote, local string) error { for walker.Step() { if err := walker.Err(); err != nil { - fmt.Fprintln(os.Stderr, err) + log.Println(err) continue } @@ -152,9 +166,10 @@ func (c *Client) downloadDir(remote, local string) error { //UploadFile 上传本地文件 localFile 到sftp远程目录 remote func (c *Client) UploadFile(localFile, remote string) error { // localFile = strings.TrimSuffix(localFile, "/") + // localFile = filepath.ToSlash(localFile) info, err := os.Stat(localFile) if err != nil || info.IsDir() { - return errors.New("本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传") + return errors.New("sftp: 本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传") } l, err := os.Open(localFile) @@ -165,15 +180,28 @@ func (c *Client) UploadFile(localFile, remote string) error { var remoteFile, remoteDir string if remote[len(remote)-1] == '/' { - remoteFile = filepath.Join(remote, filepath.Base(localFile)) + remoteFile = filepath.ToSlash(filepath.Join(remote, filepath.Base(localFile))) remoteDir = remote } else { remoteFile = remote - remoteDir = filepath.Dir(remoteFile) + remoteDir = filepath.ToSlash(filepath.Dir(remoteFile)) + } + log.Println("sftp: UploadFile", localFile, remoteFile) + if info.Size() > 1000 { + // 1. 检测远程是否存在 + rsum := c.Md5File(remoteFile) + if rsum != "" { + lsum, _ := Md5File(localFile) + if lsum == rsum { + log.Println("sftp: 文件与本地一致,跳过上传!", localFile) + return nil + } + } } // 目录不存在,则创建 remoteDir if _, err := c.SFTPClient.Stat(remoteDir); err != nil { + log.Println("sftp: Mkdir all", remoteDir) c.MkdirAll(remoteDir) } @@ -194,16 +222,17 @@ func (c *Client) UploadDir(localDir string, remoteDir string) (err error) { // } // }() // 本地输入检测,必须是目录 + // localDir = filepath.ToSlash(localDir) info, err := os.Stat(localDir) if err != nil || !info.IsDir() { - return errors.New("本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传") + return errors.New("sftp: 本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传") } // 模仿 rsync localDir不以'/'结尾,则创建尾目录 if localDir[len(localDir)-1] != '/' { - remoteDir = filepath.Join(remoteDir, filepath.Base(localDir)) + remoteDir = filepath.ToSlash(filepath.Join(remoteDir, filepath.Base(localDir))) } - // fmt.Println("remoteDir", remoteDir) + log.Println("sftp: UploadDir", localDir, remoteDir) rootDst := strings.TrimSuffix(remoteDir, "/") if c.IsFile(rootDst) { @@ -230,14 +259,18 @@ func (c *Client) UploadDir(localDir string, remoteDir string) (err error) { // it should exist and we might not even own it if finalDst == remoteDir { return nil - fmt.Println("skip", remoteDir, "--->", finalDst) + log.Println("sftp: ", remoteDir, "--->", finalDst) } if info.IsDir() { - c.MkdirAll(finalDst) + err := c.MkdirAll(finalDst) + if err != nil { + log.Println("sftp: MkdirAll", err) + } + // log.Println("MkdirAll", finalDst) // err = c.SFTPClient.Mkdir(finalDst) - // fmt.Println(err) + // log.Println(err) // if err := c.SFTPClient.Mkdir(finalDst); err != nil { // // Do not consider it an error if the directory existed // remoteFi, fiErr := c.SFTPClient.Lstat(finalDst) @@ -295,7 +328,7 @@ func (c *Client) RemoveFile(remoteFile string) error { func (c *Client) RemoveDir(remoteDir string) error { remoteFiles, err := c.SFTPClient.ReadDir(remoteDir) if err != nil { - log.Printf("remove remote dir: %s err: %v\n", remoteDir, err) + log.Printf("sftp: remove remote dir: %s err: %v\n", remoteDir, err) return err } for _, file := range remoteFiles { @@ -307,7 +340,7 @@ func (c *Client) RemoveDir(remoteDir string) error { } } c.SFTPClient.RemoveDirectory(remoteDir) //must empty dir to remove - log.Printf("remove remote dir: %s ok\n", remoteDir) + log.Printf("sftp: remove remote dir: %s ok\n", remoteDir) return nil } @@ -319,9 +352,11 @@ func (c *Client) RemoveAll(remoteDir string) error { //MkdirAll 创建目录,递归 func (c *Client) MkdirAll(dirpath string) error { - parentDir := filepath.Dir(dirpath) + + parentDir := filepath.ToSlash(filepath.Dir(dirpath)) _, err := c.SFTPClient.Stat(parentDir) if err != nil { + // log.Println(err) if err.Error() == "file does not exist" { err := c.MkdirAll(parentDir) if err != nil { @@ -331,23 +366,13 @@ func (c *Client) MkdirAll(dirpath string) error { return err } } - err = c.SFTPClient.Mkdir(dirpath) + err = c.SFTPClient.Mkdir(filepath.ToSlash(dirpath)) if err != nil { return err } return nil } -func MkdirAll(path string) error { - // 检测文件夹是否存在 若不存在 创建文件夹 - if _, err := os.Stat(path); err != nil { - if os.IsNotExist(err) { - return os.MkdirAll(path, os.ModePerm) - } - } - return nil -} - func (c *Client) Mkdir(path string, fi os.FileInfo) error { log.Printf("[DEBUG] sftp: creating dir %s", path) @@ -376,6 +401,15 @@ func (c *Client) IsDir(path string) bool { return false } +//Size 获取文件大小 +func (c *Client) Size(path string) int64 { + info, err := c.SFTPClient.Stat(path) + if err != nil { + return 0 + } + return info.Size() +} + //IsFile 检查远程是否是个文件 func (c *Client) IsFile(path string) bool { info, err := c.SFTPClient.Stat(path) @@ -397,3 +431,16 @@ func (c *Client) IsExist(path string) bool { _, err := c.SFTPClient.Stat(path) return err == nil } + +//Md5File 检查远程是文件是否存在 +func (c *Client) Md5File(path string) string { + if c.IsNotExist(path) { + return "" + } + b, err := c.Run("md5sum " + path) + if err != nil { + return "" + } + return string(bytes.Split(b, []byte{' '})[0]) + +} diff --git a/sftp_test.go b/sftp_test.go index d5c4c93..d5e8ee7 100644 --- a/sftp_test.go +++ b/sftp_test.go @@ -12,7 +12,7 @@ func GetClient() *Client { err error ) once.Do(func() { - c, err = NewClient("root", "localhost", "22", "ubuntu") + c, err = NewClient("localhost", "22", "root", "ubuntu") }) if err != nil { panic(err) diff --git a/ssh.go b/ssh.go index 1539d1b..78d5644 100644 --- a/ssh.go +++ b/ssh.go @@ -39,7 +39,7 @@ func (c *Client) RunScript(scriptPath string) ([]byte, error) { defer session.Close() // 1. 上传 script - remotePath := fmt.Sprintf("/tmp/%s", filepath.Base(scriptPath)) + remotePath := fmt.Sprintf("/tmp/script/%s", filepath.Base(scriptPath)) if err := c.UploadFile(scriptPath, remotePath); err != nil { return nil, err } diff --git a/util.go b/util.go new file mode 100644 index 0000000..da26b8b --- /dev/null +++ b/util.go @@ -0,0 +1,47 @@ +package ssh + +import ( + "bufio" + "crypto" + "encoding/hex" + "io" + "os" +) + +func FileExist(file string) bool { + if _, err := os.Stat(file); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} +func MkdirAll(path string) error { + // 检测文件夹是否存在 若不存在 创建文件夹 + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return os.MkdirAll(path, os.ModePerm) + } + } + return nil +} + +//Md5File 计算md5 +func Md5File(filename string) (string, error) { + f, err := os.Open(filename) + if err != nil { + return "", err + } + defer f.Close() + + r := bufio.NewReader(f) + + hash := crypto.MD5.New() + _, err = io.Copy(hash, r) + if err != nil { + return "", err + } + + out := hex.EncodeToString(hash.Sum(nil)) + return out, nil +}