diff --git a/client.go b/client.go index 8746911..a3bc332 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,26 @@ type Client struct { SFTPClient *sftp.Client } +func NewDSN() (client *Client) { + return nil +} +func Connect(cnf *Config) (client *Client, err error) { + + return nil, nil +} + +func (cnf *Config) Connect() (client *Client, err error) { + + return nil, nil +} + +// Close the underlying SSH connection +func (c *Client) Close() { + c.SFTPClient.Close() + c.SSHClient.Close() + c.SSHSession.Close() +} + // New 创建SSH client func New(cnf *Config) (client *Client, err error) { clientConfig := &ssh.ClientConfig{ @@ -33,15 +53,19 @@ func New(cnf *Config) (client *Client, err error) { } // 1. privite key file - if len(cnf.KeyFiles) == 0 { - - if auth, err := AuthWithPrivateKey(KeyFile(), cnf.Passphrase); err == nil { - clientConfig.Auth = append(clientConfig.Auth, auth) - } - } else { + if len(cnf.KeyFiles) != 0 { if auth, err := AuthWithPrivateKeys(cnf.KeyFiles, cnf.Passphrase); err == nil { clientConfig.Auth = append(clientConfig.Auth, auth) } + + } else { + keypath := KeyFile() + if FileExist(keypath) { + if auth, err := AuthWithPrivateKey(keypath, cnf.Passphrase); err == nil { + clientConfig.Auth = append(clientConfig.Auth, auth) + } + } + } // 2. 密码方式 放在key之后,这样密钥失败之后可以使用Password方式 if cnf.Password != "" { @@ -149,10 +173,3 @@ func NewWithPrivateKey(Host, Port, User, Passphrase string) (client *Client, err return &Client{SSHClient: sshClient, SFTPClient: sftpClient}, nil } - -// Close the underlying SSH connection -func (c *Client) Close() { - c.SFTPClient.Close() - c.SSHClient.Close() - c.SSHSession.Close() -} diff --git a/client_test.go b/client_test.go index ced679f..caef444 100644 --- a/client_test.go +++ b/client_test.go @@ -13,7 +13,7 @@ func TestNewWithAgent(t *testing.T) { return } defer c.Close() - b, err := c.Run("id") + b, err := c.Output("id") if err != nil { fmt.Println(err) return @@ -28,7 +28,7 @@ func TestNewClient(t *testing.T) { return } defer c.Close() - b, err := c.Run("id") + b, err := c.Output("id") if err != nil { fmt.Println(err) return @@ -43,7 +43,7 @@ func TestNewWithPrivateKey(t *testing.T) { return } defer c.Close() - b, err := c.Run("id") + b, err := c.Output("id") if err != nil { fmt.Println(err) return diff --git a/example/download/main.go b/example/download/main.go index cb3c772..5285ab6 100644 --- a/example/download/main.go +++ b/example/download/main.go @@ -1,24 +1,80 @@ package main import ( + "fmt" + "log" + "os" + "path" + "time" + + "github.com/yeka/zip" + "github.com/pytool/ssh" ) +var ( + err error + // sftpClient *sftp.Client +) +var FORMAT = "2006-01-02" +var dbnames = []string{"tower", "mengyin", "pingyi", "shizhi", "tancheng", "yinan", "yishui", "feixian", "gaoxinqu", "hedong", "jingkaiqu", "junan", "luozhuang", "lanling", "lanshan", "lingang", "linshu"} + func main() { - client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") - if err != nil { - panic(err) - } - defer client.Close() - var remotedir = "/root/test/" - // download dir - var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/" - client.Download(remotedir, local) + // client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") + // if err != nil { + // panic(err) + // } + // defer client.Close() - // upload file - var remotefile = "/root/test/file" + tmp := os.TempDir() - client.Download(remotefile, local) + fmt.Println(tmp) + // var remotedir = "/root/test/" + // // download dir + // var local = "/home/ubuntu/go/src/github.com/pytool/ssh/test/download/" + // client.Download(remotedir, local) + + // // download file + // var remotefile = "/root/test/file" + // client.Download(remotefile, local) } +func Down() { + // tmp:=os.TempDir() + tmp_dir := "db_" + time.Now().Format(FORMAT) + os.Mkdir(tmp_dir, 0755) + +} +func NewZipWriter(name string) *zip.Writer { + zipname, err := os.Create(name) + if err != nil { + log.Fatalln(err) + } + return zip.NewWriter(zipname) +} +func DownLoadZip(client *ssh.Client, zw *zip.Writer, src string) { + // fmt.Println(src, "数据正在复制中,请耐心等待...") + srcFile, err := client.SFTPClient.Open(src) + if err != nil { + log.Println(err) + return + } + defer srcFile.Close() + + var localFileName = path.Base(src) + // dstFile, err := os.Create(path.Join(dst, localFileName)) + // if err != nil { + // log.Println(err) + // return + // } + // defer dstFile.Close() + w, err := zw.Encrypt(localFileName, `hangruan2017`, zip.AES256Encryption) + if err != nil { + return + } + if _, err = srcFile.WriteTo(w); err != nil { + log.Println(err) + return + } +} diff --git a/example/getdata/main.go b/example/getdata/main.go index 917f937..d6785f8 100644 --- a/example/getdata/main.go +++ b/example/getdata/main.go @@ -1,12 +1,16 @@ -package getdata +package main import ( "fmt" + "io" "log" "os" "path" + "strconv" "time" + "github.com/yeka/zip" + "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) @@ -65,24 +69,37 @@ func main() { // 用来测试的远程文件路径 和 本地文件夹 // fmt.Println(shizhi) // var localDir = "." - date_dir := "db_" + time.Now().Format(FORMAT) + var dsts []string + tmp := os.TempDir() + date_dir := path.Join(tmp, "db_"+time.Now().Format(FORMAT)) os.Mkdir(date_dir, 0755) var lzkpbi = "/docker/backup/" + time.Now().Format(FORMAT) + "_lzkp_bi_inner.zip" Down(lzkpbi, date_dir) + dsts = append(dsts, path.Join(date_dir, time.Now().Format(FORMAT)+"_lzkp_bi_inner.zip")) for _, n := range dbnames { p := "/docker/backup/" + time.Now().Format(FORMAT) + "_" + n + "_inner.zip" // fmt.Println(p) Down(p, date_dir) + dsts = append(dsts, path.Join(date_dir, time.Now().Format(FORMAT)+"_"+n+"_inner.zip")) + } + + zippass("", dsts...) // fmt.Scanln() + for _, v := range dsts { + // fmt.Println(v) + // ioutil.WriteFile(v, []byte("aaa"), 755) + os.Remove(v) + } + } func Down(src, dst string) { - fmt.Println(src, "数据正在复制中,请耐心等待...") + // fmt.Println(src, "数据正在复制中,请耐心等待...") srcFile, err := sftpClient.Open(src) if err != nil { - log.Println(err) + // log.Println(err) return } defer srcFile.Close() @@ -90,16 +107,41 @@ func Down(src, dst string) { var localFileName = path.Base(src) dstFile, err := os.Create(path.Join(dst, localFileName)) if err != nil { - log.Println(err) + // log.Println(err) return } defer dstFile.Close() if _, err = srcFile.WriteTo(dstFile); err != nil { - log.Println(err) + // log.Println(err) return } - fmt.Println(src, "数据复制完成!") + // fmt.Println(src, "数据复制完成!") } + +func zippass(dst string, src ...string) { + fzip, err := os.Create(`D:/待测试数据.zip`) + if err != nil { + log.Fatalln(err) + } + zipw := zip.NewWriter(fzip) + defer zipw.Close() + for i, n := range src { + w, err := zipw.Encrypt(strconv.Itoa(i), `hangruan2017`, zip.AES256Encryption) + if err != nil { + log.Fatal(err) + } + f, err := os.Open(n) + if err != nil { + return + } + + _, err = io.Copy(w, f) + if err != nil { + log.Fatal(err) + } + } + zipw.Flush() +} diff --git a/example/mysql-proxy/main.go b/example/mysql-proxy/main.go new file mode 100644 index 0000000..7a7616c --- /dev/null +++ b/example/mysql-proxy/main.go @@ -0,0 +1,84 @@ +package main + +import ( + "database/sql" + "fmt" + "net" + + "github.com/pytool/ssh" + + "github.com/go-sql-driver/mysql" +) + +var dsn = `lzkp:yqhtfjzm@tcp(192.168.5.100:3306)/?parseTime=true&loc=Local` +var DBNAME = "shizhi" +var db *sql.DB + +func Prepare() { + var err error + db, err = sql.Open("mysql", dsn) + if err != nil { + // return FAIL, fmt.Errorf("Unable to open connection to database server: %s", err.Error()) + fmt.Print("") + } + // defer db.Close() + err = db.Ping() + if err != nil { + // return FAIL, fmt.Errorf("Unable to ping database server: %s", err.Error()) + fmt.Print("") + } + // _, err = db.Exec("CREATE DATABASE IF NOT EXISTS" + DBNAME) + // if err != nil { + // // return FAIL, fmt.Errorf("Unable to create database %s: %s", DBNAME, err.Error()) + // fmt.Print("") + // } + // defer db.Exec("DROP DATABASE dbgrep") + _, err = db.Exec("use " + DBNAME) + if err != nil { + fmt.Errorf("Unable to select database %s: %s", DBNAME, err.Error()) + } + // return m.Run(), nil +} + +func main() { + // SSH的连接参数: + config := ssh.Default.WithPassword("HR2018!!").WithHost("192.168.5.157") + client, err := ssh.New(config) + // client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") + if err != nil { + panic(err) + } + defer client.Close() + fmt.Println(client.Output("id")) + + // 1. 注册自定义的 Dial 命名为:mysql+ssh + // Now we register the ViaSSHDialer with the ssh connection as a parameter + mysql.RegisterDial("mysql+ssh", func(addr string) (net.Conn, error) { + return client.SSHClient.Dial("tcp", addr) + }) + + // DB数据库的连接参数: + dbUser := "root" // DB username + dbPass := "" // DB Password + dbHost := "localhost:3306" // DB Hostname/IP + dbName := "shizhi" // Database name + // 2. 使用自定义命名为:mysql+ssh的 Dial 进行mysql连接 + // And now we can use our new driver with the regular mysql connection string tunneled through the SSH connection + if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mysql+ssh(%s)/%s", dbUser, dbPass, dbHost, dbName)); err == nil { + + if rows, err := db.Query("SELECT user, host FROM mysql.user "); err == nil { + for rows.Next() { + var id string + var name string + rows.Scan(&id, &name) + fmt.Printf("ID: %s\tName: %s\n", id, name) + } + rows.Close() + } else { + fmt.Printf("Failure: %s", err.Error()) + } + + db.Close() + fmt.Printf("Successfully connected to the db\n") + } +} diff --git a/example/runetl/main.go b/example/runetl/main.go index 37fb721..1b3976d 100644 --- a/example/runetl/main.go +++ b/example/runetl/main.go @@ -7,7 +7,7 @@ import ( ) func main() { - config := ssh.Default.WithHost("192.168.5.157").WithPassword("HR2018!!") + config := ssh.Default.WithHost("15.14.12.153").WithPassword("HR2018!!") // config.Host = "15.14.12.153" client, err := ssh.New(config) // client, err := ssh.NewClient("localhost", "22", "root", "ubuntu") diff --git a/sftp.go b/sftp.go index ebb477f..8ebaab8 100644 --- a/sftp.go +++ b/sftp.go @@ -57,6 +57,7 @@ func (c *Client) downloadFile(remoteFile, local string) error { } else { localFile = local } + localFile = filepath.ToSlash(localFile) if c.Size(remoteFile) > 1000 { rsum := c.Md5File(remoteFile) ioutil.WriteFile(localFile+".md5", []byte(rsum), 755) diff --git a/sftp_is_test.go b/sftp_check_test.go similarity index 100% rename from sftp_is_test.go rename to sftp_check_test.go diff --git a/ssh.go b/ssh.go index f6770a2..bb82879 100644 --- a/ssh.go +++ b/ssh.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "os" "path/filepath" ) @@ -17,7 +18,7 @@ func (c *Client) Run(cmd string) { } defer session.Close() - session.Start(cmd) + err = session.Start(cmd) if err != nil { fmt.Printf("exec command:%v error:%v\n", cmd, err) } @@ -30,6 +31,29 @@ func (c *Client) Run(cmd string) { return } +//Exec Execute cmd on the remote host and bind stderr and stdout +func (c *Client) Exec1(cmd string) error { + + // New Session + session, err := c.SSHClient.NewSession() + if err != nil { + return err + } + defer session.Close() + + // go func() { + // time.Sleep(2419200 * time.Second) + // conn.Close() + // }() + + session.Stdout = os.Stdout + session.Stderr = os.Stderr + err = session.Run(cmd) + session.Close() + return nil + +} + //Exec Execute cmd on the remote host and bind stderr and stdout func (c *Client) Exec(cmd string) error { session, err := c.SSHClient.NewSession()