golang lib ssh
This commit is contained in:
parent
22f8a4e050
commit
a2d60ca1f3
|
@ -0,0 +1,95 @@
|
|||
|
||||
## 项目简介
|
||||
本项目是基于golang标准库 ssh 和 sftp 开发
|
||||
|
||||
本项目是对标准库进行一个简单的高层封装,使得可以在在 Windows Linux Mac 上非常容易的执行 ssh 命令,
|
||||
以及文件,文件夹的上传,下载等操作.
|
||||
|
||||
文件上传下载模仿rsync方式: 只和源有关.
|
||||
// rsync -av src/ dst ./src/* --> /root/dst/*
|
||||
// rsync -av src/ dst/ ./src/* --> /root/dst/*
|
||||
// rsync -av src dst ./src/* --> /root/dst/src/*
|
||||
// rsync -av src dst/ ./src/* --> /root/dst/src/*
|
||||
|
||||
## Example
|
||||
|
||||
### 在远程执行ssh命令
|
||||
```go
|
||||
package main
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pytool/ssh"
|
||||
)
|
||||
func main() {
|
||||
|
||||
c, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
output, err := c.Exec("uptime")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Uptime: %s\n", output)
|
||||
}
|
||||
|
||||
```
|
||||
### 文件下载
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/pytool/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer client.Close()
|
||||
var remotedir = "/root/test/"
|
||||
// download dir
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/"
|
||||
client.Download(remotedir, local)
|
||||
|
||||
// upload file
|
||||
var remotefile = "/root/test/file"
|
||||
|
||||
client.Download(remotefile, local)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### 文件上传
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/pytool/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer client.Close()
|
||||
var remotedir = "/root/test/"
|
||||
// upload dir
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
|
||||
client.Upload(local, remotedir)
|
||||
|
||||
// upload file
|
||||
local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/file"
|
||||
client.Upload(local, remotedir)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
// An implementation of ssh.KeyboardInteractiveChallenge that simply sends
|
||||
// back the password for all questions. The questions are logged.
|
||||
func passwordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallenge {
|
||||
return func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
||||
// log.Printf("Keyboard interactive challenge: ")
|
||||
// log.Printf("-- User: %s", user)
|
||||
// log.Printf("-- Instructions: %s", instruction)
|
||||
// for i, question := range questions {
|
||||
// log.Printf("-- Question %d: %s", i+1, question)
|
||||
// }
|
||||
|
||||
// Just send the password back for all questions
|
||||
answers := make([]string, len(questions))
|
||||
for i := range answers {
|
||||
answers[i] = password
|
||||
}
|
||||
|
||||
return answers, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithKeyboardPassword Generate a password-auth'd ssh ClientConfig
|
||||
func WithKeyboardPassword(password string) (ssh.AuthMethod, error) {
|
||||
return ssh.KeyboardInteractive(passwordKeyboardInteractive(password)), nil
|
||||
}
|
||||
|
||||
// WithPassword Generate a password-auth'd ssh ClientConfig
|
||||
func WithPassword(password string) (ssh.AuthMethod, error) {
|
||||
return ssh.Password(password), nil
|
||||
}
|
||||
|
||||
// WithAgent use already authed user
|
||||
func WithAgent() (ssh.AuthMethod, error) {
|
||||
sock := os.Getenv("SSH_AUTH_SOCK")
|
||||
if sock != "" {
|
||||
// fmt.Println(errors.New("Agent Disabled"))
|
||||
return nil, errors.New("Agent Disabled")
|
||||
}
|
||||
socks, err := net.Dial("unix", sock)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
// 1. 返回Signers函数的结果
|
||||
agent := agent.NewClient(socks)
|
||||
signers, err := agent.Signers()
|
||||
return ssh.PublicKeys(signers...), nil
|
||||
// 2. 返回Signers函数
|
||||
// getSigners := agent.NewClient(socks).Signers
|
||||
// return ssh.PublicKeysCallback(getSigners), nil
|
||||
|
||||
// 3.简写方式
|
||||
// if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
|
||||
// return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
|
||||
// }
|
||||
// return nil
|
||||
}
|
||||
|
||||
// WithPrivateKeys 设置多个 ~/.ssh/id_rsa
|
||||
func WithPrivateKeys(keyFiles []string, password string) (ssh.AuthMethod, error) {
|
||||
var signers []ssh.Signer
|
||||
|
||||
for _, key := range keyFiles {
|
||||
|
||||
buffer, err := ioutil.ReadFile(key)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
// return
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKeyWithPassphrase([]byte(buffer), []byte(password))
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
} else {
|
||||
signers = append(signers, signer)
|
||||
}
|
||||
}
|
||||
if signers == nil {
|
||||
return nil, errors.New("WithPrivateKeys: no keyfiles input")
|
||||
}
|
||||
return ssh.PublicKeys(signers...), nil
|
||||
}
|
||||
|
||||
// WithPrivateKey 自动监测是否带有密码
|
||||
func WithPrivateKey(keyfile string, password string) (ssh.AuthMethod, error) {
|
||||
pemBytes, err := ioutil.ReadFile(keyfile)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signer ssh.Signer
|
||||
signer, err = ssh.ParsePrivateKey(pemBytes)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "cannot decode encrypted private keys") {
|
||||
if signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password)); err == nil {
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// WithPrivateKeyString 直接通过字符串
|
||||
func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) {
|
||||
var signer ssh.Signer
|
||||
var err error
|
||||
if password == "" {
|
||||
signer, err = ssh.ParsePrivateKey([]byte(key))
|
||||
} else {
|
||||
signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(key), []byte(password))
|
||||
}
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
// WithPrivateKeyTerminal 通过终端读取带密码的 PublicKey
|
||||
func WithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) {
|
||||
// fmt.Fprintf(os.Stderr, "This SSH key is encrypted. Please enter passphrase for key '%s':", priv.path)
|
||||
passphrase, err := terminal.ReadPassword(int(syscall.Stdin))
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr)
|
||||
|
||||
pemBytes, err := ioutil.ReadFile(keyfile)
|
||||
if err != nil {
|
||||
|
||||
println(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, passphrase)
|
||||
if err != nil {
|
||||
|
||||
fmt.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const DefaultTimeout = 30 * time.Second
|
||||
|
||||
type Client struct {
|
||||
*Config
|
||||
SSHClient *ssh.Client
|
||||
SFTPClient *sftp.Client
|
||||
}
|
||||
|
||||
// NewClient 根据配置
|
||||
func NewClient(user, host, port, password string) (client *Client, err error) {
|
||||
p, err := strconv.Atoi(port)
|
||||
if err == nil || p == 0 {
|
||||
p = 22
|
||||
}
|
||||
if user == "" {
|
||||
user = "root"
|
||||
}
|
||||
var config = &Config{
|
||||
User: user,
|
||||
Host: host,
|
||||
Port: p,
|
||||
Password: password,
|
||||
// KeyFiles: []string{"~/.ssh/id_rsa"},
|
||||
}
|
||||
return New(config)
|
||||
}
|
||||
|
||||
// New 创建SSH client
|
||||
func New(config *Config) (client *Client, err error) {
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: config.User,
|
||||
Timeout: DefaultTimeout,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
// 2. 密码方式
|
||||
if config.Password != "" {
|
||||
clientConfig.Auth = append(clientConfig.Auth, ssh.Password(config.Password))
|
||||
}
|
||||
|
||||
// 3. privite key file
|
||||
if len(config.KeyFiles) == 0 {
|
||||
keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
|
||||
if auth, err := WithPrivateKey(keyPath, config.Password); err != nil {
|
||||
clientConfig.Auth = append(clientConfig.Auth, auth)
|
||||
}
|
||||
} else {
|
||||
if auth, err := WithPrivateKeys(config.KeyFiles, config.Password); err != nil {
|
||||
clientConfig.Auth = append(clientConfig.Auth, auth)
|
||||
}
|
||||
}
|
||||
// 1. agent
|
||||
if auth, err := WithAgent(); err != nil {
|
||||
clientConfig.Auth = append(clientConfig.Auth, auth)
|
||||
}
|
||||
if config.Port == 0 {
|
||||
config.Port = 22
|
||||
}
|
||||
// hostPort := config.Host + ":" + strconv.Itoa(config.Port)
|
||||
sshClient, err := ssh.Dial("tcp", net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), clientConfig)
|
||||
|
||||
if err != nil {
|
||||
return client, errors.New("Failed to dial ssh: " + err.Error())
|
||||
}
|
||||
|
||||
// create sftp client
|
||||
var sftpClient *sftp.Client
|
||||
if sftpClient, err = sftp.NewClient(sshClient); err != nil {
|
||||
return client, errors.New("Failed to conn sftp: " + err.Error())
|
||||
}
|
||||
|
||||
return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil
|
||||
}
|
||||
|
||||
// Execute cmd on the remote host and return stderr and stdout
|
||||
func (c *Client) Exec(cmd string) ([]byte, error) {
|
||||
session, err := c.SSHClient.NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
return session.CombinedOutput(cmd)
|
||||
}
|
||||
|
||||
// Close the underlying SSH connection
|
||||
func (c *Client) Close() {
|
||||
c.SFTPClient.Close()
|
||||
c.SSHClient.Close()
|
||||
}
|
||||
|
||||
func addPortToHost(host string) string {
|
||||
_, _, err := net.SplitHostPort(host)
|
||||
|
||||
// We got an error so blindly try to add a port number
|
||||
if err != nil {
|
||||
return net.JoinHostPort(host, "22")
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pytool/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
output, err := client.Exec("uptime")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Uptime: %s\n", output)
|
||||
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
User string
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
KeyFiles []string
|
||||
|
||||
// DisableAgentForwarding, if true, will not forward the SSH agent.
|
||||
DisableAgentForwarding bool
|
||||
|
||||
// HandshakeTimeout limits the amount of time we'll wait to handshake before
|
||||
// saying the connection failed.
|
||||
HandshakeTimeout time.Duration
|
||||
|
||||
// KeepAliveInterval sets how often we send a channel request to the
|
||||
// server. A value < 0 disables.
|
||||
KeepAliveInterval time.Duration
|
||||
|
||||
// Timeout is how long to wait for a read or write to succeed.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
var DefaultConfig = &Config{
|
||||
User: "root",
|
||||
Port: 22,
|
||||
KeyFiles: []string{"~/.ssh/id_rsa"},
|
||||
}
|
||||
|
||||
//
|
||||
func (c *Client) WithUser(user string) *Client {
|
||||
if user == "" {
|
||||
user = "root"
|
||||
}
|
||||
c.User = user
|
||||
return c
|
||||
}
|
||||
|
||||
//
|
||||
func (c *Client) WithHost(host string) *Client {
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
c.Host = host
|
||||
return c
|
||||
}
|
||||
func (c *Client) WithPassword(password string) *Client {
|
||||
c.Password = password
|
||||
return c
|
||||
}
|
||||
|
||||
//
|
||||
func (c *Client) SetKeys(keyfiles []string) *Client {
|
||||
if keyfiles == nil {
|
||||
return c
|
||||
}
|
||||
t := make([]string, len(keyfiles))
|
||||
copy(t, keyfiles)
|
||||
c.KeyFiles = t
|
||||
return c
|
||||
}
|
||||
|
||||
//
|
||||
func (c *Client) WithKey(keyfile string) *Client {
|
||||
if keyfile == "" {
|
||||
keyfile = "~/.ssh/id_rsa"
|
||||
}
|
||||
for _, s := range c.KeyFiles {
|
||||
if s == keyfile {
|
||||
return c
|
||||
}
|
||||
}
|
||||
c.KeyFiles = append(c.KeyFiles, keyfile)
|
||||
return c
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/pytool/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer client.Close()
|
||||
var remotedir = "/root/test/"
|
||||
// download dir
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/"
|
||||
client.Download(remotedir, local)
|
||||
|
||||
// upload file
|
||||
var remotefile = "/root/test/file"
|
||||
|
||||
client.Download(remotefile, local)
|
||||
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/pytool/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer client.Close()
|
||||
var remotedir = "/root/test/"
|
||||
// upload dir
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
|
||||
client.Upload(local, remotedir)
|
||||
// upload file
|
||||
local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/file"
|
||||
client.Upload(local, remotedir)
|
||||
|
||||
}
|
|
@ -0,0 +1,399 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Upload 上传本地文件 local 到sftp远程目录 remote like rsync
|
||||
// rsync -av src/ dst ./src/* --> /root/dst/*
|
||||
// rsync -av src/ dst/ ./src/* --> /root/dst/*
|
||||
// rsync -av src dst ./src/* --> /root/dst/src/*
|
||||
// rsync -av src dst/ ./src/* --> /root/dst/src/*
|
||||
func (c *Client) Upload(local string, remote string) (err error) {
|
||||
// var localDir, localFile, remoteDir, remoteFile string
|
||||
|
||||
info, err := os.Stat(local)
|
||||
if err != nil {
|
||||
return errors.New("本地文件不存在或格式错误 Upload(\"" + local + "\") 跳过上传")
|
||||
}
|
||||
if info.IsDir() {
|
||||
return c.UploadDir(local, remote)
|
||||
}
|
||||
return c.UploadFile(local, remote)
|
||||
}
|
||||
|
||||
// Download 下载sftp远程文件 remote 到本地 local like rsync
|
||||
func (c *Client) Download(remote string, local string) (err error) {
|
||||
if c.IsNotExist(strings.TrimSuffix(remote, "/")) {
|
||||
return errors.New("文件不存在,跳过文件下载 \"" + remote + "\" ")
|
||||
}
|
||||
if c.IsDir(remote) {
|
||||
// return errors.New("检测到远程是文件不是目录 \"" + remote + "\" 跳过下载")
|
||||
return c.downloadDir(remote, local)
|
||||
|
||||
}
|
||||
return c.downloadFile(remote, local)
|
||||
|
||||
}
|
||||
|
||||
// downloadFile a file from the remote server like cp
|
||||
func (c *Client) downloadFile(remoteFile, local string) error {
|
||||
// remoteFile = strings.TrimSuffix(remoteFile, "/")
|
||||
if !c.IsFile(remoteFile) {
|
||||
return errors.New("文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")")
|
||||
}
|
||||
var localFile string
|
||||
if local[len(local)-1] == '/' {
|
||||
localFile = filepath.Join(local, filepath.Base(remoteFile))
|
||||
} else {
|
||||
localFile = local
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(localFile), os.ModePerm); err != nil {
|
||||
// fmt.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := c.SFTPClient.Open(remoteFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
l, err := os.Create(localFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
_, err = io.Copy(l, r)
|
||||
return err
|
||||
}
|
||||
|
||||
// downloadDir from remote dir to local dir like rsync
|
||||
// rsync -av src/ dst ./src/* --> /root/dst/*
|
||||
// rsync -av src/ dst/ ./src/* --> /root/dst/*
|
||||
// rsync -av src dst ./src/* --> /root/dst/src/*
|
||||
// rsync -av src dst/ ./src/* --> /root/dst/src/*
|
||||
func (c *Client) downloadDir(remote, local string) error {
|
||||
var localDir, remoteDir string
|
||||
|
||||
if !c.IsDir(remote) {
|
||||
return errors.New("目录不存在或不是目录, 跳过 downloadDir(" + remote + ")")
|
||||
}
|
||||
remoteDir = remote
|
||||
if remote[len(remote)-1] == '/' {
|
||||
localDir = local
|
||||
} else {
|
||||
localDir = path.Join(local, path.Base(remote))
|
||||
}
|
||||
|
||||
walker := c.SFTPClient.Walk(remoteDir)
|
||||
|
||||
for walker.Step() {
|
||||
if err := walker.Err(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
info := walker.Stat()
|
||||
|
||||
relPath, err := filepath.Rel(remoteDir, walker.Path())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
localPath := filepath.ToSlash(filepath.Join(localDir, relPath))
|
||||
|
||||
// if we have something at the download path delete it if it is a directory
|
||||
// and the remote is a file and vice a versa
|
||||
localInfo, err := os.Stat(localPath)
|
||||
if os.IsExist(err) {
|
||||
if localInfo.IsDir() {
|
||||
if info.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
err = os.RemoveAll(localPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if info.IsDir() {
|
||||
err = os.Remove(localPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
err = os.MkdirAll(localPath, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
c.downloadFile(walker.Path(), localPath)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//UploadFile 上传本地文件 localFile 到sftp远程目录 remote
|
||||
func (c *Client) UploadFile(localFile, remote string) error {
|
||||
// localFile = strings.TrimSuffix(localFile, "/")
|
||||
info, err := os.Stat(localFile)
|
||||
if err != nil || info.IsDir() {
|
||||
return errors.New("本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传")
|
||||
}
|
||||
|
||||
l, err := os.Open(localFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
var remoteFile, remoteDir string
|
||||
if remote[len(remote)-1] == '/' {
|
||||
remoteFile = filepath.Join(remote, filepath.Base(localFile))
|
||||
remoteDir = remote
|
||||
} else {
|
||||
remoteFile = remote
|
||||
remoteDir = filepath.Dir(remoteFile)
|
||||
}
|
||||
|
||||
// 目录不存在,则创建 remoteDir
|
||||
if _, err := c.SFTPClient.Stat(remoteDir); err != nil {
|
||||
c.MkdirAll(remoteDir)
|
||||
}
|
||||
|
||||
r, err := c.SFTPClient.Create(remoteFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(r, l)
|
||||
return err
|
||||
}
|
||||
|
||||
// UploadDir files without checking diff status
|
||||
func (c *Client) UploadDir(localDir string, remoteDir string) (err error) {
|
||||
// defer func() {
|
||||
// if err != nil {
|
||||
// err = errors.New("UploadDir " + err.Error())
|
||||
// }
|
||||
// }()
|
||||
// 本地输入检测,必须是目录
|
||||
info, err := os.Stat(localDir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return errors.New("本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传")
|
||||
}
|
||||
|
||||
// 模仿 rsync localDir不以'/'结尾,则创建尾目录
|
||||
if localDir[len(localDir)-1] != '/' {
|
||||
remoteDir = filepath.Join(remoteDir, filepath.Base(localDir))
|
||||
}
|
||||
// fmt.Println("remoteDir", remoteDir)
|
||||
|
||||
rootDst := strings.TrimSuffix(remoteDir, "/")
|
||||
if c.IsFile(rootDst) {
|
||||
c.SFTPClient.Remove(rootDst)
|
||||
}
|
||||
|
||||
walkFunc := func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Calculate the final destination using the
|
||||
// base source and root destination
|
||||
relSrc, err := filepath.Rel(localDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
finalDst := filepath.Join(rootDst, relSrc)
|
||||
|
||||
// In Windows, Join uses backslashes which we don't want to get
|
||||
// to the sftp server
|
||||
finalDst = filepath.ToSlash(finalDst)
|
||||
|
||||
// Skip the creation of the target destination directory since
|
||||
// it should exist and we might not even own it
|
||||
if finalDst == remoteDir {
|
||||
return nil
|
||||
fmt.Println("skip", remoteDir, "--->", finalDst)
|
||||
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
c.MkdirAll(finalDst)
|
||||
// err = c.SFTPClient.Mkdir(finalDst)
|
||||
// fmt.Println(err)
|
||||
// if err := c.SFTPClient.Mkdir(finalDst); err != nil {
|
||||
// // Do not consider it an error if the directory existed
|
||||
// remoteFi, fiErr := c.SFTPClient.Lstat(finalDst)
|
||||
// if fiErr != nil || !remoteFi.IsDir() {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
// return err
|
||||
} else {
|
||||
// f, err := os.Open(path)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer f.Close()
|
||||
return c.UploadFile(path, finalDst)
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
return filepath.Walk(localDir, walkFunc)
|
||||
}
|
||||
|
||||
// Remove a file from the remote server
|
||||
func (c *Client) Remove(path string) error {
|
||||
return c.SFTPClient.Remove(path)
|
||||
}
|
||||
|
||||
// RemoveDirectory Remove a directory from the remote server
|
||||
func (c *Client) RemoveDirectory(path string) error {
|
||||
return c.SFTPClient.RemoveDirectory(path)
|
||||
}
|
||||
|
||||
// ReadAll Read a remote file and return the contents.
|
||||
func (c *Client) ReadAll(filepath string) ([]byte, error) {
|
||||
file, err := c.SFTPClient.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
return ioutil.ReadAll(file)
|
||||
}
|
||||
|
||||
//FileExist 文件是否存在
|
||||
func (c *Client) FileExist(filepath string) (bool, error) {
|
||||
if _, err := c.SFTPClient.Stat(filepath); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *Client) RemoveFile(remoteFile string) error {
|
||||
return c.SFTPClient.Remove(remoteFile)
|
||||
}
|
||||
func (c *Client) RemoveDir(remoteDir string) error {
|
||||
remoteFiles, err := c.SFTPClient.ReadDir(remoteDir)
|
||||
if err != nil {
|
||||
log.Printf("remove remote dir: %s err: %v\n", remoteDir, err)
|
||||
return err
|
||||
}
|
||||
for _, file := range remoteFiles {
|
||||
subRemovePath := path.Join(remoteDir, file.Name())
|
||||
if file.IsDir() {
|
||||
c.RemoveDir(subRemovePath)
|
||||
} else {
|
||||
c.RemoveFile(subRemovePath)
|
||||
}
|
||||
}
|
||||
c.SFTPClient.RemoveDirectory(remoteDir) //must empty dir to remove
|
||||
log.Printf("remove remote dir: %s ok\n", remoteDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
//RemoveAll 递归删除目录,文件
|
||||
func (c *Client) RemoveAll(remoteDir string) error {
|
||||
c.RemoveDir(remoteDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
//MkdirAll 创建目录,递归
|
||||
func (c *Client) MkdirAll(dirpath string) error {
|
||||
parentDir := filepath.Dir(dirpath)
|
||||
_, err := c.SFTPClient.Stat(parentDir)
|
||||
if err != nil {
|
||||
if err.Error() == "file does not exist" {
|
||||
err := c.MkdirAll(parentDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = c.SFTPClient.Mkdir(dirpath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func MkdirAll(path string) error {
|
||||
// 检测文件夹是否存在 若不存在 创建文件夹
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return os.MkdirAll(path, os.ModePerm)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Mkdir(path string, fi os.FileInfo) error {
|
||||
log.Printf("[DEBUG] sftp: creating dir %s", path)
|
||||
|
||||
if err := c.SFTPClient.Mkdir(path); err != nil {
|
||||
// Do not consider it an error if the directory existed
|
||||
remoteFi, fiErr := c.SFTPClient.Lstat(path)
|
||||
if fiErr != nil || !remoteFi.IsDir() {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
mode := fi.Mode().Perm()
|
||||
if err := c.SFTPClient.Chmod(path, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//IsDir 检查远程是否是个目录
|
||||
func (c *Client) IsDir(path string) bool {
|
||||
// 检查远程是文件还是目录
|
||||
info, err := c.SFTPClient.Stat(path)
|
||||
if err == nil && info.IsDir() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//IsFile 检查远程是否是个文件
|
||||
func (c *Client) IsFile(path string) bool {
|
||||
info, err := c.SFTPClient.Stat(path)
|
||||
if err == nil && !info.IsDir() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
//IsNotExist 检查远程是文件是否不存在
|
||||
func (c *Client) IsNotExist(path string) bool {
|
||||
_, err := c.SFTPClient.Stat(path)
|
||||
return err != nil
|
||||
}
|
||||
|
||||
//IsExist 检查远程是文件是否存在
|
||||
func (c *Client) IsExist(path string) bool {
|
||||
|
||||
_, err := c.SFTPClient.Stat(path)
|
||||
return err == nil
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClient_IsCheck(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
var remotes = []string{
|
||||
"/root/test/notExist",
|
||||
"/root/test/notExist/",
|
||||
"/root/test/file",
|
||||
"/root/test/file/", // 不存在
|
||||
"/root/test/dir",
|
||||
"/root/test/dir/",
|
||||
}
|
||||
|
||||
// /root/test/file 存在
|
||||
// /root/test/file/ 不存在
|
||||
// /root/test/dir 存在
|
||||
// /root/test/dir/ 存在
|
||||
for _, v := range remotes {
|
||||
is := c.IsExist(v)
|
||||
if is {
|
||||
println(v, "\t存在")
|
||||
} else {
|
||||
println(v, "\t不存在")
|
||||
}
|
||||
}
|
||||
|
||||
// /root/test/file 不是一个目录
|
||||
// /root/test/file/ 不是一个目录
|
||||
// /root/test/dir 是一个目录
|
||||
// /root/test/dir/ 是一个目录
|
||||
println()
|
||||
for _, v := range remotes {
|
||||
isdir := c.IsDir(v)
|
||||
if isdir {
|
||||
println(v, "\t是一个目录")
|
||||
} else {
|
||||
println(v, "\t不是一个目录")
|
||||
}
|
||||
}
|
||||
|
||||
// /root/test/file 是一个文件
|
||||
// /root/test/file/ 不是一个文件
|
||||
// /root/test/dir 不是一个文件
|
||||
// /root/test/dir/ 不是一个文件
|
||||
println()
|
||||
for _, v := range remotes {
|
||||
isfile := c.IsFile(v)
|
||||
if isfile {
|
||||
println(v, "\t是一个文件")
|
||||
} else {
|
||||
println(v, "\t不是一个文件")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func GetClient() *Client {
|
||||
var (
|
||||
once = sync.Once{}
|
||||
c = &Client{}
|
||||
err error
|
||||
)
|
||||
once.Do(func() {
|
||||
c, err = NewClient("root", "localhost", "22", "ubuntu")
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// func TestClient_RemoveAll(t *testing.T) {
|
||||
// c := GetClient()
|
||||
// defer c.Close()
|
||||
// var remotedir = "/root/test/"
|
||||
// fmt.Println(c.RemoveAll(remotedir))
|
||||
// }
|
||||
func TestClient_Init(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
|
||||
|
||||
var remotedir = "/root/test/"
|
||||
c.RemoveAll("/root/upload/")
|
||||
|
||||
err := c.Upload(local, remotedir)
|
||||
if err != nil {
|
||||
println("[Upload]", err.Error())
|
||||
}
|
||||
}
|
||||
func TestClient_Upload(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
|
||||
var uploads = map[string][]string{
|
||||
local + "null/": []string{"/root/upload/test/null/1", "/root/upload/test/null/2/"},
|
||||
local + "null/": []string{"/root/upload/test/null/3", "/root/upload/test/null/4/"},
|
||||
local + "file": []string{"/root/upload/test/file/1", "/root/upload/test/file/2/"},
|
||||
local + "file/": []string{"/root/upload/test/file/3", "/root/upload/test/file/4/"},
|
||||
local + "dir": []string{"/root/upload/test/dir/1", "/root/upload/test/dir/2/"},
|
||||
local + "dir/": []string{"/root/upload/test/dir/3", "/root/upload/test/dir/4/"},
|
||||
}
|
||||
|
||||
for local, remotes := range uploads {
|
||||
for _, remote := range remotes {
|
||||
err := c.Upload(local, remote)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
// println(remote, "--->", v, "Finish download!")
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Download(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download"
|
||||
var download = map[string][]string{
|
||||
"/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"},
|
||||
"/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"},
|
||||
"/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"},
|
||||
"/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"},
|
||||
"/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"},
|
||||
"/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"},
|
||||
}
|
||||
|
||||
for remote, local := range download {
|
||||
for _, v := range local {
|
||||
err := c.Download(remote, v)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
// println(remote, "--->", v, "Finish download!")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestClient_DownloadFile(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/downloadfile"
|
||||
var download = map[string][]string{
|
||||
"/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"},
|
||||
"/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"},
|
||||
"/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"},
|
||||
"/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"},
|
||||
"/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"},
|
||||
"/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"},
|
||||
}
|
||||
|
||||
for remote, local := range download {
|
||||
for _, v := range local {
|
||||
err := c.downloadFile(remote, v)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
// println(remote, "--->", v, "Finish download!")
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestClient_DownloadDir(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/downloaddir"
|
||||
var download = map[string][]string{
|
||||
"/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"},
|
||||
"/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"},
|
||||
"/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"},
|
||||
"/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"},
|
||||
"/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"},
|
||||
"/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"},
|
||||
}
|
||||
|
||||
for remote, local := range download {
|
||||
for _, v := range local {
|
||||
err := c.downloadDir(remote, v)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
// println(remote, "--->", v, "Finish download!")
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestClient_UploadFile(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
|
||||
var uploads = map[string][]string{
|
||||
local + "null": []string{"/root/upload/file_test/null/1", "/root/upload/file_test/null/2/"},
|
||||
local + "null/": []string{"/root/upload/file_test/null/3", "/root/upload/file_test/null/4/"},
|
||||
local + "file": []string{"/root/upload/file_test/file/1", "/root/upload/file_test/file/2/"},
|
||||
local + "file/": []string{"/root/upload/file_test/file/3", "/root/upload/file_test/file/4/"},
|
||||
local + "dir": []string{"/root/upload/file_test/dir/1", "/root/upload/file_test/dir/2/"},
|
||||
local + "dir/": []string{"/root/upload/file_test/dir/3", "/root/upload/file_test/dir/4/"},
|
||||
}
|
||||
|
||||
for local, remotes := range uploads {
|
||||
for _, remote := range remotes {
|
||||
err := c.UploadFile(local, remote)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
// println(remote, "--->", v, "Finish download!")
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_UploadDir(t *testing.T) {
|
||||
c := GetClient()
|
||||
defer c.Close()
|
||||
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
|
||||
var uploads = map[string][]string{
|
||||
local + "null/": []string{"/root/upload/dir_test/null/1", "/root/upload/dir_test/null/2/"},
|
||||
local + "null/": []string{"/root/upload/dir_test/null/3", "/root/upload/dir_test/null/4/"},
|
||||
local + "file": []string{"/root/upload/dir_test/file/1", "/root/upload/dir_test/file/2/"},
|
||||
local + "file/": []string{"/root/upload/dir_test/file/3", "/root/upload/dir_test/file/4/"},
|
||||
local + "dir": []string{"/root/upload/dir_test/dir/1", "/root/upload/dir_test/dir/2/"},
|
||||
local + "dir/": []string{"/root/upload/dir_test/dir/3", "/root/upload/dir_test/dir/4/"},
|
||||
}
|
||||
|
||||
for local, remotes := range uploads {
|
||||
for _, remote := range remotes {
|
||||
err := c.UploadDir(local, remote)
|
||||
if err != nil {
|
||||
println(err.Error())
|
||||
}
|
||||
// println(remote, "--->", v, "Finish download!")
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// This is the phrase that tells us sudo is looking for a password via stdin
|
||||
const sudoPwPrompt = "sudo_password"
|
||||
|
||||
// sudoWriter is used to both combine stdout and stderr as well as
|
||||
// look for a password request from sudo.
|
||||
type sudoWriter struct {
|
||||
b bytes.Buffer
|
||||
pw string // The password to pass to sudo (if requested)
|
||||
stdin io.Writer // The writer from the ssh session
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
func (w *sudoWriter) Write(p []byte) (int, error) {
|
||||
// If we get the sudo password prompt phrase send the password via stdin
|
||||
// and don't write it to the buffer.
|
||||
if string(p) == sudoPwPrompt {
|
||||
w.stdin.Write([]byte(w.pw + "\n"))
|
||||
w.pw = "" // We don't need the password anymore so reset the string
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
w.m.Lock()
|
||||
defer w.m.Unlock()
|
||||
|
||||
return w.b.Write(p)
|
||||
}
|
||||
|
||||
// ExecSu Execute cmd via sudo. Do not include the sudo command in
|
||||
// the cmd string. For example: Client.ExecSudo("uptime", "password").
|
||||
// If you are using passwordless sudo you can use the regular Exec()
|
||||
// function.
|
||||
func (c *Client) ExecSu(cmd, passwd string) ([]byte, error) {
|
||||
session, err := c.SSHClient.NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// -n run non interactively
|
||||
// -p specify the prompt. We do this to know that sudo is asking for a passwd
|
||||
// -S Writes the prompt to StdErr and reads the password from StdIn
|
||||
cmd = "sudo -p " + sudoPwPrompt + " -S " + cmd
|
||||
|
||||
// Use the sudoRW struct to handle the interaction with sudo and capture the
|
||||
// output of the command
|
||||
w := &sudoWriter{
|
||||
pw: passwd,
|
||||
}
|
||||
w.stdin, err = session.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Combine stdout, stderr to the same writer which also looks for the sudo
|
||||
// password prompt
|
||||
session.Stdout = w
|
||||
session.Stderr = w
|
||||
|
||||
err = session.Run(cmd)
|
||||
|
||||
return w.b.Bytes(), err
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
dir file
|
|
@ -0,0 +1 @@
|
|||
this is subdir file
|
|
@ -0,0 +1 @@
|
|||
file
|
Loading…
Reference in New Issue