fix uploadfile for windows&linux

This commit is contained in:
rinetd 2018-11-13 17:42:29 +08:00
parent 375b34f1ea
commit 0e936a4f58
11 changed files with 197 additions and 84 deletions

View File

@ -4,13 +4,15 @@
本项目是对标准库进行一个简单的高层封装,使得可以在在 Windows Linux Mac 上非常容易的执行 ssh 命令,
以及文件,文件夹的上传,下载等操作.
文件上传下载模仿rsync方式: 只和源有关.
1. 当src 为目录时
文件上传下载模仿rsync: 只和源有关.
// rsync -av src/ dst ./src/* --> /root/dst/*
// rsync -av src/ dst/ ./src/* --> /root/dst/*
// rsync -av src dst ./src/* --> /root/dst/src/*
// rsync -av src dst/ ./src/* --> /root/dst/src/*
2. 当src 为文件时
当dst为目录以"/"结尾,则自动拼接上文件名
当dst为文件不以“/”结尾时,则重命名文件
## Install
`go get github.com/pytool/ssh`
## Example
@ -24,7 +26,7 @@ import (
)
func main() {
c, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
c, err := ssh.NewClient("localhost", "22", "root", "ubuntu")
if err != nil {
panic(err)
}
@ -49,7 +51,7 @@ import (
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
client, err := ssh.NewClient( "localhost", "22", "root", "ubuntu")
if err != nil {
panic(err)
}
@ -77,7 +79,7 @@ import (
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
client, err := ssh.NewClient( "localhost", "22", "root", "ubuntu")
if err != nil {
panic(err)
}

28
auth.go
View File

@ -52,18 +52,18 @@ func passwordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallen
}
}
// WithKeyboardPassword Generate a password-auth'd ssh ClientConfig
func WithKeyboardPassword(password string) (ssh.AuthMethod, error) {
// AuthWithKeyboardPassword Generate a password-auth'd ssh ClientConfig
func AuthWithKeyboardPassword(password string) (ssh.AuthMethod, error) {
return ssh.KeyboardInteractive(passwordKeyboardInteractive(password)), nil
}
// WithPassword Generate a password-auth'd ssh ClientConfig
func WithPassword(password string) (ssh.AuthMethod, error) {
// AuthWithPassword Generate a password-auth'd ssh ClientConfig
func AuthWithPassword(password string) (ssh.AuthMethod, error) {
return ssh.Password(password), nil
}
// WithAgent use already authed user
func WithAgent() (ssh.AuthMethod, error) {
// AuthWithAgent use already authed user
func AuthWithAgent() (ssh.AuthMethod, error) {
sock := os.Getenv("SSH_AUTH_SOCK")
if sock == "" {
// fmt.Println(errors.New("Agent Disabled"))
@ -89,8 +89,8 @@ func WithAgent() (ssh.AuthMethod, error) {
// return nil
}
// WithPrivateKeys 设置多个 ~/.ssh/id_rsa ,如果加密用passphrase尝试
func WithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, error) {
// AuthWithPrivateKeys 设置多个 ~/.ssh/id_rsa ,如果加密用passphrase尝试
func AuthWithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, error) {
var signers []ssh.Signer
for _, key := range keyFiles {
@ -118,8 +118,8 @@ func WithPrivateKeys(keyFiles []string, passphrase string) (ssh.AuthMethod, erro
return ssh.PublicKeys(signers...), nil
}
// WithPrivateKey 自动监测是否带有密码
func WithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) {
// AuthWithPrivateKey 自动监测是否带有密码
func AuthWithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) {
pemBytes, err := ioutil.ReadFile(keyfile)
if err != nil {
println(err.Error())
@ -141,8 +141,8 @@ func WithPrivateKey(keyfile string, passphrase string) (ssh.AuthMethod, error) {
}
// WithPrivateKeyString 直接通过字符串
func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) {
// AuthWithPrivateKeyString 直接通过字符串
func AuthWithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) {
var signer ssh.Signer
var err error
if password == "" {
@ -157,8 +157,8 @@ func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) {
return ssh.PublicKeys(signer), nil
}
// WithPrivateKeyTerminal 通过终端读取带密码的 PublicKey
func WithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) {
// AuthWithPrivateKeyTerminal 通过终端读取带密码的 PublicKey
func AuthWithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) {
pemBytes, err := ioutil.ReadFile(keyfile)
if err != nil {

View File

@ -35,11 +35,11 @@ func New(cnf *Config) (client *Client, err error) {
// 1. privite key file
if len(cnf.KeyFiles) == 0 {
keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
if auth, err := WithPrivateKey(keyPath, cnf.Passphrase); err == nil {
if auth, err := AuthWithPrivateKey(keyPath, cnf.Passphrase); err == nil {
clientConfig.Auth = append(clientConfig.Auth, auth)
}
} else {
if auth, err := WithPrivateKeys(cnf.KeyFiles, cnf.Passphrase); err == nil {
if auth, err := AuthWithPrivateKeys(cnf.KeyFiles, cnf.Passphrase); err == nil {
clientConfig.Auth = append(clientConfig.Auth, auth)
}
}
@ -48,7 +48,7 @@ func New(cnf *Config) (client *Client, err error) {
clientConfig.Auth = append(clientConfig.Auth, ssh.Password(cnf.Password))
}
// 3. agent 模式放在最后,这样当前两者都不能使用时可以采用Agent模式
if auth, err := WithAgent(); err == nil {
if auth, err := AuthWithAgent(); err == nil {
clientConfig.Auth = append(clientConfig.Auth, auth)
}
@ -94,7 +94,7 @@ func NewWithAgent(Host, Port, User string) (client *Client, err error) {
Timeout: DefaultTimeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
auth, err := WithAgent()
auth, err := AuthWithAgent()
if err != nil {
return nil, err
}
@ -108,7 +108,7 @@ func NewWithAgent(Host, Port, User string) (client *Client, err error) {
// create sftp client
var sftpClient *sftp.Client
if sftpClient, err = sftp.NewClient(sshClient); err != nil {
if sftpClient, err = sftp.NewClient(sshClient, sftp.MaxPacket(10240000)); err != nil {
return client, errors.New("Failed to conn sftp: " + err.Error())
}
return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil
@ -122,7 +122,7 @@ func NewWithPrivateKey(Host, Port, User, Passphrase string) (client *Client, err
}
// 3. privite key file
keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
auth, err := WithPrivateKey(keyPath, Passphrase)
auth, err := AuthWithPrivateKey(keyPath, Passphrase)
if err != nil {
fmt.Println(err)
return nil, err

View File

@ -7,8 +7,9 @@ import (
)
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
config := ssh.Default.WithPassword("ubuntu")
client, err := ssh.New(config)
// client, err := ssh.NewClient("localhost", "22", "root", "ubuntu")
if err != nil {
panic(err)
}

View File

@ -1,6 +1,8 @@
package ssh
import (
"os"
"path"
"time"
)
@ -29,48 +31,51 @@ type Config struct {
}
var DefaultConfig = &Config{
User: "root",
Host: "localhost",
Port: 22,
KeyFiles: []string{"~/.ssh/id_rsa"},
User: "root",
// KeyFiles: []string{path.Join(os.Getenv("HOME"), "/.ssh/id_rsa")},
}
var Default = DefaultConfig
//
func (c *Client) WithUser(user string) *Client {
func WithUser(user string) *Config {
return Default.WithUser(user)
}
func (c *Config) WithUser(user string) *Config {
if user == "" {
user = "root"
}
c.User = user
return c
}
func WithHost(host string) *Config {
return Default.WithHost(host)
}
//
func (c *Client) WithHost(host string) *Client {
func (c *Config) WithHost(host string) *Config {
if host == "" {
host = "localhost"
}
c.Host = host
return c
}
func (c *Client) WithPassword(password string) *Client {
func WithPassword(password string) *Config {
return Default.WithPassword(password)
}
func (c *Config) WithPassword(password string) *Config {
c.Password = password
return c
}
//
func (c *Client) SetKeys(keyfiles []string) *Client {
if keyfiles == nil {
return c
func WithKey(keyfile, passphrase string) *Config {
return Default.WithKey(keyfile, passphrase)
}
t := make([]string, len(keyfiles))
copy(t, keyfiles)
c.KeyFiles = t
return c
}
//
func (c *Client) WithKey(keyfile string) *Client {
func (c *Config) WithKey(keyfile, passphrase string) *Config {
if keyfile == "" {
keyfile = "~/.ssh/id_rsa"
if home := os.Getenv("HOME"); home != "" {
keyfile = path.Join(home, "/.ssh/id_rsa")
}
}
for _, s := range c.KeyFiles {
if s == keyfile {
@ -80,3 +85,14 @@ func (c *Client) WithKey(keyfile string) *Client {
c.KeyFiles = append(c.KeyFiles, keyfile)
return c
}
//
func (c *Config) SetKeys(keyfiles []string) *Config {
if keyfiles == nil {
return c
}
t := make([]string, len(keyfiles))
copy(t, keyfiles)
c.KeyFiles = t
return c
}

View File

@ -6,7 +6,7 @@ import (
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
client, err := ssh.NewClient("localhost", "22", "root", "ubuntu")
if err != nil {
panic(err)
}

View File

@ -6,7 +6,7 @@ import (
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
client, err := ssh.NewClient("localhost", "22", "root", "ubuntu")
if err != nil {
panic(err)
}

109
sftp.go
View File

@ -1,8 +1,8 @@
package ssh
import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
@ -22,9 +22,10 @@ func (c *Client) Upload(local string, remote string) (err error) {
info, err := os.Stat(local)
if err != nil {
return errors.New("本地文件不存在或格式错误 Upload(\"" + local + "\") 跳过上传")
return errors.New("sftp: 跳过上传 Upload(\"" + local + "\") ,本地文件不存在或格式错误!")
}
if info.IsDir() {
log.Println("sftp: UploadDir", local)
return c.UploadDir(local, remote)
}
return c.UploadFile(local, remote)
@ -33,7 +34,7 @@ func (c *Client) Upload(local string, remote string) (err error) {
// Download 下载sftp远程文件 remote 到本地 local like rsync
func (c *Client) Download(remote string, local string) (err error) {
if c.IsNotExist(strings.TrimSuffix(remote, "/")) {
return errors.New("文件不存在,跳过文件下载 \"" + remote + "\" ")
return errors.New("sftp: 远程文件不存在,跳过文件下载 \"" + remote + "\" ")
}
if c.IsDir(remote) {
// return errors.New("检测到远程是文件不是目录 \"" + remote + "\" 跳过下载")
@ -48,7 +49,7 @@ func (c *Client) Download(remote string, local string) (err error) {
func (c *Client) downloadFile(remoteFile, local string) error {
// remoteFile = strings.TrimSuffix(remoteFile, "/")
if !c.IsFile(remoteFile) {
return errors.New("文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")")
return errors.New("sftp: 文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")")
}
var localFile string
if local[len(local)-1] == '/' {
@ -56,9 +57,22 @@ func (c *Client) downloadFile(remoteFile, local string) error {
} else {
localFile = local
}
if c.Size(remoteFile) > 1000 {
rsum := c.Md5File(remoteFile)
ioutil.WriteFile(localFile+".md5", []byte(rsum), 755)
if FileExist(localFile) {
// 1. 检测远程是否存在
if rsum != "" {
lsum, _ := Md5File(localFile)
if lsum == rsum {
log.Println("sftp: 文件与本地一致,跳过上传!", localFile)
return nil
}
}
}
}
if err := os.MkdirAll(filepath.Dir(localFile), os.ModePerm); err != nil {
// fmt.Println(err)
// log.Println(err)
return err
}
@ -87,7 +101,7 @@ func (c *Client) downloadDir(remote, local string) error {
var localDir, remoteDir string
if !c.IsDir(remote) {
return errors.New("目录不存在或不是目录, 跳过 downloadDir(" + remote + ")")
return errors.New("sftp: 目录不存在或不是目录, 跳过 downloadDir(" + remote + ")")
}
remoteDir = remote
if remote[len(remote)-1] == '/' {
@ -100,7 +114,7 @@ func (c *Client) downloadDir(remote, local string) error {
for walker.Step() {
if err := walker.Err(); err != nil {
fmt.Fprintln(os.Stderr, err)
log.Println(err)
continue
}
@ -152,9 +166,10 @@ func (c *Client) downloadDir(remote, local string) error {
//UploadFile 上传本地文件 localFile 到sftp远程目录 remote
func (c *Client) UploadFile(localFile, remote string) error {
// localFile = strings.TrimSuffix(localFile, "/")
// localFile = filepath.ToSlash(localFile)
info, err := os.Stat(localFile)
if err != nil || info.IsDir() {
return errors.New("本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传")
return errors.New("sftp: 本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传")
}
l, err := os.Open(localFile)
@ -165,15 +180,28 @@ func (c *Client) UploadFile(localFile, remote string) error {
var remoteFile, remoteDir string
if remote[len(remote)-1] == '/' {
remoteFile = filepath.Join(remote, filepath.Base(localFile))
remoteFile = filepath.ToSlash(filepath.Join(remote, filepath.Base(localFile)))
remoteDir = remote
} else {
remoteFile = remote
remoteDir = filepath.Dir(remoteFile)
remoteDir = filepath.ToSlash(filepath.Dir(remoteFile))
}
log.Println("sftp: UploadFile", localFile, remoteFile)
if info.Size() > 1000 {
// 1. 检测远程是否存在
rsum := c.Md5File(remoteFile)
if rsum != "" {
lsum, _ := Md5File(localFile)
if lsum == rsum {
log.Println("sftp: 文件与本地一致,跳过上传!", localFile)
return nil
}
}
}
// 目录不存在,则创建 remoteDir
if _, err := c.SFTPClient.Stat(remoteDir); err != nil {
log.Println("sftp: Mkdir all", remoteDir)
c.MkdirAll(remoteDir)
}
@ -194,16 +222,17 @@ func (c *Client) UploadDir(localDir string, remoteDir string) (err error) {
// }
// }()
// 本地输入检测,必须是目录
// localDir = filepath.ToSlash(localDir)
info, err := os.Stat(localDir)
if err != nil || !info.IsDir() {
return errors.New("本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传")
return errors.New("sftp: 本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传")
}
// 模仿 rsync localDir不以'/'结尾,则创建尾目录
if localDir[len(localDir)-1] != '/' {
remoteDir = filepath.Join(remoteDir, filepath.Base(localDir))
remoteDir = filepath.ToSlash(filepath.Join(remoteDir, filepath.Base(localDir)))
}
// fmt.Println("remoteDir", remoteDir)
log.Println("sftp: UploadDir", localDir, remoteDir)
rootDst := strings.TrimSuffix(remoteDir, "/")
if c.IsFile(rootDst) {
@ -230,14 +259,18 @@ func (c *Client) UploadDir(localDir string, remoteDir string) (err error) {
// it should exist and we might not even own it
if finalDst == remoteDir {
return nil
fmt.Println("skip", remoteDir, "--->", finalDst)
log.Println("sftp: ", remoteDir, "--->", finalDst)
}
if info.IsDir() {
c.MkdirAll(finalDst)
err := c.MkdirAll(finalDst)
if err != nil {
log.Println("sftp: MkdirAll", err)
}
// log.Println("MkdirAll", finalDst)
// err = c.SFTPClient.Mkdir(finalDst)
// fmt.Println(err)
// log.Println(err)
// if err := c.SFTPClient.Mkdir(finalDst); err != nil {
// // Do not consider it an error if the directory existed
// remoteFi, fiErr := c.SFTPClient.Lstat(finalDst)
@ -295,7 +328,7 @@ func (c *Client) RemoveFile(remoteFile string) error {
func (c *Client) RemoveDir(remoteDir string) error {
remoteFiles, err := c.SFTPClient.ReadDir(remoteDir)
if err != nil {
log.Printf("remove remote dir: %s err: %v\n", remoteDir, err)
log.Printf("sftp: remove remote dir: %s err: %v\n", remoteDir, err)
return err
}
for _, file := range remoteFiles {
@ -307,7 +340,7 @@ func (c *Client) RemoveDir(remoteDir string) error {
}
}
c.SFTPClient.RemoveDirectory(remoteDir) //must empty dir to remove
log.Printf("remove remote dir: %s ok\n", remoteDir)
log.Printf("sftp: remove remote dir: %s ok\n", remoteDir)
return nil
}
@ -319,9 +352,11 @@ func (c *Client) RemoveAll(remoteDir string) error {
//MkdirAll 创建目录,递归
func (c *Client) MkdirAll(dirpath string) error {
parentDir := filepath.Dir(dirpath)
parentDir := filepath.ToSlash(filepath.Dir(dirpath))
_, err := c.SFTPClient.Stat(parentDir)
if err != nil {
// log.Println(err)
if err.Error() == "file does not exist" {
err := c.MkdirAll(parentDir)
if err != nil {
@ -331,23 +366,13 @@ func (c *Client) MkdirAll(dirpath string) error {
return err
}
}
err = c.SFTPClient.Mkdir(dirpath)
err = c.SFTPClient.Mkdir(filepath.ToSlash(dirpath))
if err != nil {
return err
}
return nil
}
func MkdirAll(path string) error {
// 检测文件夹是否存在 若不存在 创建文件夹
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return os.MkdirAll(path, os.ModePerm)
}
}
return nil
}
func (c *Client) Mkdir(path string, fi os.FileInfo) error {
log.Printf("[DEBUG] sftp: creating dir %s", path)
@ -376,6 +401,15 @@ func (c *Client) IsDir(path string) bool {
return false
}
//Size 获取文件大小
func (c *Client) Size(path string) int64 {
info, err := c.SFTPClient.Stat(path)
if err != nil {
return 0
}
return info.Size()
}
//IsFile 检查远程是否是个文件
func (c *Client) IsFile(path string) bool {
info, err := c.SFTPClient.Stat(path)
@ -397,3 +431,16 @@ func (c *Client) IsExist(path string) bool {
_, err := c.SFTPClient.Stat(path)
return err == nil
}
//Md5File 检查远程是文件是否存在
func (c *Client) Md5File(path string) string {
if c.IsNotExist(path) {
return ""
}
b, err := c.Run("md5sum " + path)
if err != nil {
return ""
}
return string(bytes.Split(b, []byte{' '})[0])
}

View File

@ -12,7 +12,7 @@ func GetClient() *Client {
err error
)
once.Do(func() {
c, err = NewClient("root", "localhost", "22", "ubuntu")
c, err = NewClient("localhost", "22", "root", "ubuntu")
})
if err != nil {
panic(err)

2
ssh.go
View File

@ -39,7 +39,7 @@ func (c *Client) RunScript(scriptPath string) ([]byte, error) {
defer session.Close()
// 1. 上传 script
remotePath := fmt.Sprintf("/tmp/%s", filepath.Base(scriptPath))
remotePath := fmt.Sprintf("/tmp/script/%s", filepath.Base(scriptPath))
if err := c.UploadFile(scriptPath, remotePath); err != nil {
return nil, err
}

47
util.go Normal file
View File

@ -0,0 +1,47 @@
package ssh
import (
"bufio"
"crypto"
"encoding/hex"
"io"
"os"
)
func FileExist(file string) bool {
if _, err := os.Stat(file); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
func MkdirAll(path string) error {
// 检测文件夹是否存在 若不存在 创建文件夹
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return os.MkdirAll(path, os.ModePerm)
}
}
return nil
}
//Md5File 计算md5
func Md5File(filename string) (string, error) {
f, err := os.Open(filename)
if err != nil {
return "", err
}
defer f.Close()
r := bufio.NewReader(f)
hash := crypto.MD5.New()
_, err = io.Copy(hash, r)
if err != nil {
return "", err
}
out := hex.EncodeToString(hash.Sum(nil))
return out, nil
}