golang lib ssh

This commit is contained in:
rinetd 2018-10-30 17:45:30 +08:00
parent 22f8a4e050
commit a2d60ca1f3
14 changed files with 1245 additions and 0 deletions

95
README.md Normal file
View File

@ -0,0 +1,95 @@
## 项目简介
本项目是基于golang标准库 ssh 和 sftp 开发
本项目是对标准库进行一个简单的高层封装,使得可以在在 Windows Linux Mac 上非常容易的执行 ssh 命令,
以及文件,文件夹的上传,下载等操作.
文件上传下载模仿rsync方式: 只和源有关.
// rsync -av src/ dst ./src/* --> /root/dst/*
// rsync -av src/ dst/ ./src/* --> /root/dst/*
// rsync -av src dst ./src/* --> /root/dst/src/*
// rsync -av src dst/ ./src/* --> /root/dst/src/*
## Example
### 在远程执行ssh命令
```go
package main
import (
"fmt"
"github.com/pytool/ssh"
)
func main() {
c, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
if err != nil {
panic(err)
}
defer c.Close()
output, err := c.Exec("uptime")
if err != nil {
panic(err)
}
fmt.Printf("Uptime: %s\n", output)
}
```
### 文件下载
```go
package main
import (
"github.com/pytool/ssh"
)
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
if err != nil {
panic(err)
}
defer client.Close()
var remotedir = "/root/test/"
// download dir
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/"
client.Download(remotedir, local)
// upload file
var remotefile = "/root/test/file"
client.Download(remotefile, local)
}
```
### 文件上传
```go
package main
import (
"github.com/pytool/ssh"
)
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
if err != nil {
panic(err)
}
defer client.Close()
var remotedir = "/root/test/"
// upload dir
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
client.Upload(local, remotedir)
// upload file
local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/file"
client.Upload(local, remotedir)
}
```

160
auth.go Normal file
View File

@ -0,0 +1,160 @@
package ssh
import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"strings"
"syscall"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
)
// An implementation of ssh.KeyboardInteractiveChallenge that simply sends
// back the password for all questions. The questions are logged.
func passwordKeyboardInteractive(password string) ssh.KeyboardInteractiveChallenge {
return func(user, instruction string, questions []string, echos []bool) ([]string, error) {
// log.Printf("Keyboard interactive challenge: ")
// log.Printf("-- User: %s", user)
// log.Printf("-- Instructions: %s", instruction)
// for i, question := range questions {
// log.Printf("-- Question %d: %s", i+1, question)
// }
// Just send the password back for all questions
answers := make([]string, len(questions))
for i := range answers {
answers[i] = password
}
return answers, nil
}
}
// WithKeyboardPassword Generate a password-auth'd ssh ClientConfig
func WithKeyboardPassword(password string) (ssh.AuthMethod, error) {
return ssh.KeyboardInteractive(passwordKeyboardInteractive(password)), nil
}
// WithPassword Generate a password-auth'd ssh ClientConfig
func WithPassword(password string) (ssh.AuthMethod, error) {
return ssh.Password(password), nil
}
// WithAgent use already authed user
func WithAgent() (ssh.AuthMethod, error) {
sock := os.Getenv("SSH_AUTH_SOCK")
if sock != "" {
// fmt.Println(errors.New("Agent Disabled"))
return nil, errors.New("Agent Disabled")
}
socks, err := net.Dial("unix", sock)
if err != nil {
fmt.Println(err)
return nil, err
}
// 1. 返回Signers函数的结果
agent := agent.NewClient(socks)
signers, err := agent.Signers()
return ssh.PublicKeys(signers...), nil
// 2. 返回Signers函数
// getSigners := agent.NewClient(socks).Signers
// return ssh.PublicKeysCallback(getSigners), nil
// 3.简写方式
// if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
// return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
// }
// return nil
}
// WithPrivateKeys 设置多个 ~/.ssh/id_rsa
func WithPrivateKeys(keyFiles []string, password string) (ssh.AuthMethod, error) {
var signers []ssh.Signer
for _, key := range keyFiles {
buffer, err := ioutil.ReadFile(key)
if err != nil {
println(err.Error())
// return
}
signer, err := ssh.ParsePrivateKeyWithPassphrase([]byte(buffer), []byte(password))
if err != nil {
println(err.Error())
} else {
signers = append(signers, signer)
}
}
if signers == nil {
return nil, errors.New("WithPrivateKeys: no keyfiles input")
}
return ssh.PublicKeys(signers...), nil
}
// WithPrivateKey 自动监测是否带有密码
func WithPrivateKey(keyfile string, password string) (ssh.AuthMethod, error) {
pemBytes, err := ioutil.ReadFile(keyfile)
if err != nil {
println(err.Error())
return nil, err
}
var signer ssh.Signer
signer, err = ssh.ParsePrivateKey(pemBytes)
if err != nil {
if strings.Contains(err.Error(), "cannot decode encrypted private keys") {
if signer, err = ssh.ParsePrivateKeyWithPassphrase(pemBytes, []byte(password)); err == nil {
return ssh.PublicKeys(signer), nil
}
}
}
return nil, err
}
// WithPrivateKeyString 直接通过字符串
func WithPrivateKeyString(key string, password string) (ssh.AuthMethod, error) {
var signer ssh.Signer
var err error
if password == "" {
signer, err = ssh.ParsePrivateKey([]byte(key))
} else {
signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(key), []byte(password))
}
if err != nil {
println(err.Error())
return nil, err
}
return ssh.PublicKeys(signer), nil
}
// WithPrivateKeyTerminal 通过终端读取带密码的 PublicKey
func WithPrivateKeyTerminal(keyfile string) (ssh.AuthMethod, error) {
// fmt.Fprintf(os.Stderr, "This SSH key is encrypted. Please enter passphrase for key '%s':", priv.path)
passphrase, err := terminal.ReadPassword(int(syscall.Stdin))
if err != nil {
println(err.Error())
return nil, err
}
fmt.Fprintln(os.Stderr)
pemBytes, err := ioutil.ReadFile(keyfile)
if err != nil {
println(err.Error())
return nil, err
}
signer, err := ssh.ParsePrivateKeyWithPassphrase(pemBytes, passphrase)
if err != nil {
fmt.Println(err)
return nil, err
}
return ssh.PublicKeys(signer), nil
}

114
client.go Normal file
View File

@ -0,0 +1,114 @@
package ssh
import (
"errors"
"net"
"os"
"strconv"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
const DefaultTimeout = 30 * time.Second
type Client struct {
*Config
SSHClient *ssh.Client
SFTPClient *sftp.Client
}
// NewClient 根据配置
func NewClient(user, host, port, password string) (client *Client, err error) {
p, err := strconv.Atoi(port)
if err == nil || p == 0 {
p = 22
}
if user == "" {
user = "root"
}
var config = &Config{
User: user,
Host: host,
Port: p,
Password: password,
// KeyFiles: []string{"~/.ssh/id_rsa"},
}
return New(config)
}
// New 创建SSH client
func New(config *Config) (client *Client, err error) {
clientConfig := &ssh.ClientConfig{
User: config.User,
Timeout: DefaultTimeout,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
// 2. 密码方式
if config.Password != "" {
clientConfig.Auth = append(clientConfig.Auth, ssh.Password(config.Password))
}
// 3. privite key file
if len(config.KeyFiles) == 0 {
keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
if auth, err := WithPrivateKey(keyPath, config.Password); err != nil {
clientConfig.Auth = append(clientConfig.Auth, auth)
}
} else {
if auth, err := WithPrivateKeys(config.KeyFiles, config.Password); err != nil {
clientConfig.Auth = append(clientConfig.Auth, auth)
}
}
// 1. agent
if auth, err := WithAgent(); err != nil {
clientConfig.Auth = append(clientConfig.Auth, auth)
}
if config.Port == 0 {
config.Port = 22
}
// hostPort := config.Host + ":" + strconv.Itoa(config.Port)
sshClient, err := ssh.Dial("tcp", net.JoinHostPort(config.Host, strconv.Itoa(config.Port)), clientConfig)
if err != nil {
return client, errors.New("Failed to dial ssh: " + err.Error())
}
// create sftp client
var sftpClient *sftp.Client
if sftpClient, err = sftp.NewClient(sshClient); err != nil {
return client, errors.New("Failed to conn sftp: " + err.Error())
}
return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil
}
// Execute cmd on the remote host and return stderr and stdout
func (c *Client) Exec(cmd string) ([]byte, error) {
session, err := c.SSHClient.NewSession()
if err != nil {
return nil, err
}
defer session.Close()
return session.CombinedOutput(cmd)
}
// Close the underlying SSH connection
func (c *Client) Close() {
c.SFTPClient.Close()
c.SSHClient.Close()
}
func addPortToHost(host string) string {
_, _, err := net.SplitHostPort(host)
// We got an error so blindly try to add a port number
if err != nil {
return net.JoinHostPort(host, "22")
}
return host
}

24
cmd/main.go Normal file
View File

@ -0,0 +1,24 @@
package main
import (
"fmt"
"github.com/pytool/ssh"
)
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
if err != nil {
panic(err)
}
defer client.Close()
output, err := client.Exec("uptime")
if err != nil {
panic(err)
}
fmt.Printf("Uptime: %s\n", output)
}

80
config.go Normal file
View File

@ -0,0 +1,80 @@
package ssh
import (
"time"
)
type Config struct {
User string
Host string
Port int
Password string
KeyFiles []string
// DisableAgentForwarding, if true, will not forward the SSH agent.
DisableAgentForwarding bool
// HandshakeTimeout limits the amount of time we'll wait to handshake before
// saying the connection failed.
HandshakeTimeout time.Duration
// KeepAliveInterval sets how often we send a channel request to the
// server. A value < 0 disables.
KeepAliveInterval time.Duration
// Timeout is how long to wait for a read or write to succeed.
Timeout time.Duration
}
var DefaultConfig = &Config{
User: "root",
Port: 22,
KeyFiles: []string{"~/.ssh/id_rsa"},
}
//
func (c *Client) WithUser(user string) *Client {
if user == "" {
user = "root"
}
c.User = user
return c
}
//
func (c *Client) WithHost(host string) *Client {
if host == "" {
host = "localhost"
}
c.Host = host
return c
}
func (c *Client) WithPassword(password string) *Client {
c.Password = password
return c
}
//
func (c *Client) SetKeys(keyfiles []string) *Client {
if keyfiles == nil {
return c
}
t := make([]string, len(keyfiles))
copy(t, keyfiles)
c.KeyFiles = t
return c
}
//
func (c *Client) WithKey(keyfile string) *Client {
if keyfile == "" {
keyfile = "~/.ssh/id_rsa"
}
for _, s := range c.KeyFiles {
if s == keyfile {
return c
}
}
c.KeyFiles = append(c.KeyFiles, keyfile)
return c
}

24
example/get/main.go Normal file
View File

@ -0,0 +1,24 @@
package main
import (
"github.com/pytool/ssh"
)
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
if err != nil {
panic(err)
}
defer client.Close()
var remotedir = "/root/test/"
// download dir
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/"
client.Download(remotedir, local)
// upload file
var remotefile = "/root/test/file"
client.Download(remotefile, local)
}

22
example/put/main.go Normal file
View File

@ -0,0 +1,22 @@
package main
import (
"github.com/pytool/ssh"
)
func main() {
client, err := ssh.NewClient("root", "localhost", "22", "ubuntu")
if err != nil {
panic(err)
}
defer client.Close()
var remotedir = "/root/test/"
// upload dir
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
client.Upload(local, remotedir)
// upload file
local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/file"
client.Upload(local, remotedir)
}

399
sftp.go Normal file
View File

@ -0,0 +1,399 @@
package ssh
import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"strings"
)
// Upload 上传本地文件 local 到sftp远程目录 remote like rsync
// rsync -av src/ dst ./src/* --> /root/dst/*
// rsync -av src/ dst/ ./src/* --> /root/dst/*
// rsync -av src dst ./src/* --> /root/dst/src/*
// rsync -av src dst/ ./src/* --> /root/dst/src/*
func (c *Client) Upload(local string, remote string) (err error) {
// var localDir, localFile, remoteDir, remoteFile string
info, err := os.Stat(local)
if err != nil {
return errors.New("本地文件不存在或格式错误 Upload(\"" + local + "\") 跳过上传")
}
if info.IsDir() {
return c.UploadDir(local, remote)
}
return c.UploadFile(local, remote)
}
// Download 下载sftp远程文件 remote 到本地 local like rsync
func (c *Client) Download(remote string, local string) (err error) {
if c.IsNotExist(strings.TrimSuffix(remote, "/")) {
return errors.New("文件不存在,跳过文件下载 \"" + remote + "\" ")
}
if c.IsDir(remote) {
// return errors.New("检测到远程是文件不是目录 \"" + remote + "\" 跳过下载")
return c.downloadDir(remote, local)
}
return c.downloadFile(remote, local)
}
// downloadFile a file from the remote server like cp
func (c *Client) downloadFile(remoteFile, local string) error {
// remoteFile = strings.TrimSuffix(remoteFile, "/")
if !c.IsFile(remoteFile) {
return errors.New("文件不存在或不是文件, 跳过目录下载 downloadFile(" + remoteFile + ")")
}
var localFile string
if local[len(local)-1] == '/' {
localFile = filepath.Join(local, filepath.Base(remoteFile))
} else {
localFile = local
}
if err := os.MkdirAll(filepath.Dir(localFile), os.ModePerm); err != nil {
// fmt.Println(err)
return err
}
r, err := c.SFTPClient.Open(remoteFile)
if err != nil {
return err
}
defer r.Close()
l, err := os.Create(localFile)
if err != nil {
return err
}
defer l.Close()
_, err = io.Copy(l, r)
return err
}
// downloadDir from remote dir to local dir like rsync
// rsync -av src/ dst ./src/* --> /root/dst/*
// rsync -av src/ dst/ ./src/* --> /root/dst/*
// rsync -av src dst ./src/* --> /root/dst/src/*
// rsync -av src dst/ ./src/* --> /root/dst/src/*
func (c *Client) downloadDir(remote, local string) error {
var localDir, remoteDir string
if !c.IsDir(remote) {
return errors.New("目录不存在或不是目录, 跳过 downloadDir(" + remote + ")")
}
remoteDir = remote
if remote[len(remote)-1] == '/' {
localDir = local
} else {
localDir = path.Join(local, path.Base(remote))
}
walker := c.SFTPClient.Walk(remoteDir)
for walker.Step() {
if err := walker.Err(); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
info := walker.Stat()
relPath, err := filepath.Rel(remoteDir, walker.Path())
if err != nil {
return err
}
localPath := filepath.ToSlash(filepath.Join(localDir, relPath))
// if we have something at the download path delete it if it is a directory
// and the remote is a file and vice a versa
localInfo, err := os.Stat(localPath)
if os.IsExist(err) {
if localInfo.IsDir() {
if info.IsDir() {
continue
}
err = os.RemoveAll(localPath)
if err != nil {
return err
}
} else if info.IsDir() {
err = os.Remove(localPath)
if err != nil {
return err
}
}
}
if info.IsDir() {
err = os.MkdirAll(localPath, os.ModePerm)
if err != nil {
return err
}
continue
}
c.downloadFile(walker.Path(), localPath)
}
return nil
}
//UploadFile 上传本地文件 localFile 到sftp远程目录 remote
func (c *Client) UploadFile(localFile, remote string) error {
// localFile = strings.TrimSuffix(localFile, "/")
info, err := os.Stat(localFile)
if err != nil || info.IsDir() {
return errors.New("本地文件不存在,或是不是文件 UploadFile(\"" + localFile + "\") 跳过上传")
}
l, err := os.Open(localFile)
if err != nil {
return err
}
defer l.Close()
var remoteFile, remoteDir string
if remote[len(remote)-1] == '/' {
remoteFile = filepath.Join(remote, filepath.Base(localFile))
remoteDir = remote
} else {
remoteFile = remote
remoteDir = filepath.Dir(remoteFile)
}
// 目录不存在,则创建 remoteDir
if _, err := c.SFTPClient.Stat(remoteDir); err != nil {
c.MkdirAll(remoteDir)
}
r, err := c.SFTPClient.Create(remoteFile)
if err != nil {
return err
}
_, err = io.Copy(r, l)
return err
}
// UploadDir files without checking diff status
func (c *Client) UploadDir(localDir string, remoteDir string) (err error) {
// defer func() {
// if err != nil {
// err = errors.New("UploadDir " + err.Error())
// }
// }()
// 本地输入检测,必须是目录
info, err := os.Stat(localDir)
if err != nil || !info.IsDir() {
return errors.New("本地目录不存在或不是目录 UploadDir(\"" + localDir + "\") 跳过上传")
}
// 模仿 rsync localDir不以'/'结尾,则创建尾目录
if localDir[len(localDir)-1] != '/' {
remoteDir = filepath.Join(remoteDir, filepath.Base(localDir))
}
// fmt.Println("remoteDir", remoteDir)
rootDst := strings.TrimSuffix(remoteDir, "/")
if c.IsFile(rootDst) {
c.SFTPClient.Remove(rootDst)
}
walkFunc := func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// Calculate the final destination using the
// base source and root destination
relSrc, err := filepath.Rel(localDir, path)
if err != nil {
return err
}
finalDst := filepath.Join(rootDst, relSrc)
// In Windows, Join uses backslashes which we don't want to get
// to the sftp server
finalDst = filepath.ToSlash(finalDst)
// Skip the creation of the target destination directory since
// it should exist and we might not even own it
if finalDst == remoteDir {
return nil
fmt.Println("skip", remoteDir, "--->", finalDst)
}
if info.IsDir() {
c.MkdirAll(finalDst)
// err = c.SFTPClient.Mkdir(finalDst)
// fmt.Println(err)
// if err := c.SFTPClient.Mkdir(finalDst); err != nil {
// // Do not consider it an error if the directory existed
// remoteFi, fiErr := c.SFTPClient.Lstat(finalDst)
// if fiErr != nil || !remoteFi.IsDir() {
// return err
// }
// }
// return err
} else {
// f, err := os.Open(path)
// if err != nil {
// return err
// }
// defer f.Close()
return c.UploadFile(path, finalDst)
}
return nil
}
return filepath.Walk(localDir, walkFunc)
}
// Remove a file from the remote server
func (c *Client) Remove(path string) error {
return c.SFTPClient.Remove(path)
}
// RemoveDirectory Remove a directory from the remote server
func (c *Client) RemoveDirectory(path string) error {
return c.SFTPClient.RemoveDirectory(path)
}
// ReadAll Read a remote file and return the contents.
func (c *Client) ReadAll(filepath string) ([]byte, error) {
file, err := c.SFTPClient.Open(filepath)
if err != nil {
return nil, err
}
defer file.Close()
return ioutil.ReadAll(file)
}
//FileExist 文件是否存在
func (c *Client) FileExist(filepath string) (bool, error) {
if _, err := c.SFTPClient.Stat(filepath); err != nil {
return false, err
}
return true, nil
}
func (c *Client) RemoveFile(remoteFile string) error {
return c.SFTPClient.Remove(remoteFile)
}
func (c *Client) RemoveDir(remoteDir string) error {
remoteFiles, err := c.SFTPClient.ReadDir(remoteDir)
if err != nil {
log.Printf("remove remote dir: %s err: %v\n", remoteDir, err)
return err
}
for _, file := range remoteFiles {
subRemovePath := path.Join(remoteDir, file.Name())
if file.IsDir() {
c.RemoveDir(subRemovePath)
} else {
c.RemoveFile(subRemovePath)
}
}
c.SFTPClient.RemoveDirectory(remoteDir) //must empty dir to remove
log.Printf("remove remote dir: %s ok\n", remoteDir)
return nil
}
//RemoveAll 递归删除目录,文件
func (c *Client) RemoveAll(remoteDir string) error {
c.RemoveDir(remoteDir)
return nil
}
//MkdirAll 创建目录,递归
func (c *Client) MkdirAll(dirpath string) error {
parentDir := filepath.Dir(dirpath)
_, err := c.SFTPClient.Stat(parentDir)
if err != nil {
if err.Error() == "file does not exist" {
err := c.MkdirAll(parentDir)
if err != nil {
return err
}
} else {
return err
}
}
err = c.SFTPClient.Mkdir(dirpath)
if err != nil {
return err
}
return nil
}
func MkdirAll(path string) error {
// 检测文件夹是否存在 若不存在 创建文件夹
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return os.MkdirAll(path, os.ModePerm)
}
}
return nil
}
func (c *Client) Mkdir(path string, fi os.FileInfo) error {
log.Printf("[DEBUG] sftp: creating dir %s", path)
if err := c.SFTPClient.Mkdir(path); err != nil {
// Do not consider it an error if the directory existed
remoteFi, fiErr := c.SFTPClient.Lstat(path)
if fiErr != nil || !remoteFi.IsDir() {
return err
}
}
mode := fi.Mode().Perm()
if err := c.SFTPClient.Chmod(path, mode); err != nil {
return err
}
return nil
}
//IsDir 检查远程是否是个目录
func (c *Client) IsDir(path string) bool {
// 检查远程是文件还是目录
info, err := c.SFTPClient.Stat(path)
if err == nil && info.IsDir() {
return true
}
return false
}
//IsFile 检查远程是否是个文件
func (c *Client) IsFile(path string) bool {
info, err := c.SFTPClient.Stat(path)
if err == nil && !info.IsDir() {
return true
}
return false
}
//IsNotExist 检查远程是文件是否不存在
func (c *Client) IsNotExist(path string) bool {
_, err := c.SFTPClient.Stat(path)
return err != nil
}
//IsExist 检查远程是文件是否存在
func (c *Client) IsExist(path string) bool {
_, err := c.SFTPClient.Stat(path)
return err == nil
}

61
sftp_is_test.go Normal file
View File

@ -0,0 +1,61 @@
package ssh
import (
"testing"
)
func TestClient_IsCheck(t *testing.T) {
c := GetClient()
defer c.Close()
var remotes = []string{
"/root/test/notExist",
"/root/test/notExist/",
"/root/test/file",
"/root/test/file/", // 不存在
"/root/test/dir",
"/root/test/dir/",
}
// /root/test/file 存在
// /root/test/file/ 不存在
// /root/test/dir 存在
// /root/test/dir/ 存在
for _, v := range remotes {
is := c.IsExist(v)
if is {
println(v, "\t存在")
} else {
println(v, "\t不存在")
}
}
// /root/test/file 不是一个目录
// /root/test/file/ 不是一个目录
// /root/test/dir 是一个目录
// /root/test/dir/ 是一个目录
println()
for _, v := range remotes {
isdir := c.IsDir(v)
if isdir {
println(v, "\t是一个目录")
} else {
println(v, "\t不是一个目录")
}
}
// /root/test/file 是一个文件
// /root/test/file/ 不是一个文件
// /root/test/dir 不是一个文件
// /root/test/dir/ 不是一个文件
println()
for _, v := range remotes {
isfile := c.IsFile(v)
if isfile {
println(v, "\t是一个文件")
} else {
println(v, "\t不是一个文件")
}
}
}

193
sftp_test.go Normal file
View File

@ -0,0 +1,193 @@
package ssh
import (
"sync"
"testing"
)
func GetClient() *Client {
var (
once = sync.Once{}
c = &Client{}
err error
)
once.Do(func() {
c, err = NewClient("root", "localhost", "22", "ubuntu")
})
if err != nil {
panic(err)
}
return c
}
// func TestClient_RemoveAll(t *testing.T) {
// c := GetClient()
// defer c.Close()
// var remotedir = "/root/test/"
// fmt.Println(c.RemoveAll(remotedir))
// }
func TestClient_Init(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
var remotedir = "/root/test/"
c.RemoveAll("/root/upload/")
err := c.Upload(local, remotedir)
if err != nil {
println("[Upload]", err.Error())
}
}
func TestClient_Upload(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
var uploads = map[string][]string{
local + "null/": []string{"/root/upload/test/null/1", "/root/upload/test/null/2/"},
local + "null/": []string{"/root/upload/test/null/3", "/root/upload/test/null/4/"},
local + "file": []string{"/root/upload/test/file/1", "/root/upload/test/file/2/"},
local + "file/": []string{"/root/upload/test/file/3", "/root/upload/test/file/4/"},
local + "dir": []string{"/root/upload/test/dir/1", "/root/upload/test/dir/2/"},
local + "dir/": []string{"/root/upload/test/dir/3", "/root/upload/test/dir/4/"},
}
for local, remotes := range uploads {
for _, remote := range remotes {
err := c.Upload(local, remote)
if err != nil {
println(err.Error())
}
// println(remote, "--->", v, "Finish download!")
}
}
}
func TestClient_Download(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download"
var download = map[string][]string{
"/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"},
"/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"},
"/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"},
"/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"},
"/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"},
"/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"},
}
for remote, local := range download {
for _, v := range local {
err := c.Download(remote, v)
if err != nil {
println(err.Error())
}
// println(remote, "--->", v, "Finish download!")
}
}
}
func TestClient_DownloadFile(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/downloadfile"
var download = map[string][]string{
"/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"},
"/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"},
"/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"},
"/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"},
"/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"},
"/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"},
}
for remote, local := range download {
for _, v := range local {
err := c.downloadFile(remote, v)
if err != nil {
println(err.Error())
}
// println(remote, "--->", v, "Finish download!")
}
}
}
func TestClient_DownloadDir(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/downloaddir"
var download = map[string][]string{
"/root/test/notExist": []string{local + "/localNotExist/null/1", local + "/localNotExist/null/2/"},
"/root/test/notExist/": []string{local + "/localNotExist/null/3", local + "/localNotExist/null/4/"},
"/root/test/file": []string{local + "/localNotExist/file/1", local + "/localNotExist/file/2/"},
"/root/test/file/": []string{local + "/localNotExist/file/3", local + "/localNotExist/file/4/"},
"/root/test/dir": []string{local + "/localNotExist/dir/1", local + "/localNotExist/dir/2/"},
"/root/test/dir/": []string{local + "/localNotExist/dir/3", local + "/localNotExist/dir/4/"},
}
for remote, local := range download {
for _, v := range local {
err := c.downloadDir(remote, v)
if err != nil {
println(err.Error())
}
// println(remote, "--->", v, "Finish download!")
}
}
}
func TestClient_UploadFile(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
var uploads = map[string][]string{
local + "null": []string{"/root/upload/file_test/null/1", "/root/upload/file_test/null/2/"},
local + "null/": []string{"/root/upload/file_test/null/3", "/root/upload/file_test/null/4/"},
local + "file": []string{"/root/upload/file_test/file/1", "/root/upload/file_test/file/2/"},
local + "file/": []string{"/root/upload/file_test/file/3", "/root/upload/file_test/file/4/"},
local + "dir": []string{"/root/upload/file_test/dir/1", "/root/upload/file_test/dir/2/"},
local + "dir/": []string{"/root/upload/file_test/dir/3", "/root/upload/file_test/dir/4/"},
}
for local, remotes := range uploads {
for _, remote := range remotes {
err := c.UploadFile(local, remote)
if err != nil {
println(err.Error())
}
// println(remote, "--->", v, "Finish download!")
}
}
}
func TestClient_UploadDir(t *testing.T) {
c := GetClient()
defer c.Close()
var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/upload/"
var uploads = map[string][]string{
local + "null/": []string{"/root/upload/dir_test/null/1", "/root/upload/dir_test/null/2/"},
local + "null/": []string{"/root/upload/dir_test/null/3", "/root/upload/dir_test/null/4/"},
local + "file": []string{"/root/upload/dir_test/file/1", "/root/upload/dir_test/file/2/"},
local + "file/": []string{"/root/upload/dir_test/file/3", "/root/upload/dir_test/file/4/"},
local + "dir": []string{"/root/upload/dir_test/dir/1", "/root/upload/dir_test/dir/2/"},
local + "dir/": []string{"/root/upload/dir_test/dir/3", "/root/upload/dir_test/dir/4/"},
}
for local, remotes := range uploads {
for _, remote := range remotes {
err := c.UploadDir(local, remote)
if err != nil {
println(err.Error())
}
// println(remote, "--->", v, "Finish download!")
}
}
}

70
sudo.go Normal file
View File

@ -0,0 +1,70 @@
package ssh
import (
"bytes"
"io"
"sync"
)
// This is the phrase that tells us sudo is looking for a password via stdin
const sudoPwPrompt = "sudo_password"
// sudoWriter is used to both combine stdout and stderr as well as
// look for a password request from sudo.
type sudoWriter struct {
b bytes.Buffer
pw string // The password to pass to sudo (if requested)
stdin io.Writer // The writer from the ssh session
m sync.Mutex
}
func (w *sudoWriter) Write(p []byte) (int, error) {
// If we get the sudo password prompt phrase send the password via stdin
// and don't write it to the buffer.
if string(p) == sudoPwPrompt {
w.stdin.Write([]byte(w.pw + "\n"))
w.pw = "" // We don't need the password anymore so reset the string
return len(p), nil
}
w.m.Lock()
defer w.m.Unlock()
return w.b.Write(p)
}
// ExecSu Execute cmd via sudo. Do not include the sudo command in
// the cmd string. For example: Client.ExecSudo("uptime", "password").
// If you are using passwordless sudo you can use the regular Exec()
// function.
func (c *Client) ExecSu(cmd, passwd string) ([]byte, error) {
session, err := c.SSHClient.NewSession()
if err != nil {
return nil, err
}
defer session.Close()
// -n run non interactively
// -p specify the prompt. We do this to know that sudo is asking for a passwd
// -S Writes the prompt to StdErr and reads the password from StdIn
cmd = "sudo -p " + sudoPwPrompt + " -S " + cmd
// Use the sudoRW struct to handle the interaction with sudo and capture the
// output of the command
w := &sudoWriter{
pw: passwd,
}
w.stdin, err = session.StdinPipe()
if err != nil {
return nil, err
}
// Combine stdout, stderr to the same writer which also looks for the sudo
// password prompt
session.Stdout = w
session.Stderr = w
err = session.Run(cmd)
return w.b.Bytes(), err
}

1
test/upload/dir/file Normal file
View File

@ -0,0 +1 @@
dir file

View File

@ -0,0 +1 @@
this is subdir file

1
test/upload/file Normal file
View File

@ -0,0 +1 @@
file