diff --git a/db/files.go b/db/files.go index 9e57bd8..3c549e4 100644 --- a/db/files.go +++ b/db/files.go @@ -7,15 +7,8 @@ import ( // AddFile is used to add the md5 of a file name which is uploaded to our application // this will enable us to randomize the URL without worrying about the file names func AddFile(fileName, token string) error { - SQL, err := database.Prepare("insert into files values(?,?)") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - - if err != nil { - log.Println(err) - } + SQL := database.prepare("insert into files values(?,?)") + tx := database.begin() _, err = tx.Stmt(SQL).Exec(fileName, token) if err != nil { log.Println(err) @@ -30,7 +23,7 @@ func AddFile(fileName, token string) error { func GetFileName(token string) (string, error) { sql := "select name from files where autoName=?" var fileName string - rows, err := database.Query(sql, fileName) + rows := database.query(sql, fileName) if rows.Next() { err := rows.Scan(&fileName) if err != nil { diff --git a/db/tasks.go b/db/tasks.go index 03d3767..2b911b2 100644 --- a/db/tasks.go +++ b/db/tasks.go @@ -2,19 +2,49 @@ package db import ( "database/sql" - _ "github.com/mattn/go-sqlite3" //we want to use sqlite natively - md "github.com/shurcooL/github_flavored_markdown" - "github.com/thewhitetulip/Tasks/types" "log" "strings" "time" + + _ "github.com/mattn/go-sqlite3" //we want to use sqlite natively + md "github.com/shurcooL/github_flavored_markdown" + "github.com/thewhitetulip/Tasks/types" ) -var database *sql.DB +var database Database var err error +//Database encapsulates database +type Database struct { + db *sql.DB +} + +func (db Database) begin() (tx *sql.Tx) { + tx, err := db.db.Begin() + if err != nil { + log.Println(err) + } + return tx +} + +func (db Database) prepare(q string) (stmt *sql.Stmt) { + stmt, err := db.db.Prepare(q) + if err != nil { + log.Println(err) + } + return stmt +} + +func (db Database) query(q string, args ...interface{}) (rows *sql.Rows) { + rows, err := db.db.Query(q, args...) + if err != nil { + log.Println(err) + } + return rows +} + func init() { - database, err = sql.Open("sqlite3", "./tasks.db") + database.db, err = sql.Open("sqlite3", "./tasks.db") if err != nil { log.Println(err) } @@ -22,7 +52,7 @@ func init() { //Close function closes this database connection func Close() { - database.Close() + database.db.Close() } //GetTasks retrieves all the tasks depending on the @@ -46,10 +76,7 @@ func GetTasks(status string) types.Context { getTasksql = basicSQL + " where finish_date is not null order by priority desc, created_date asc" } - rows, err := database.Query(getTasksql) - if err != nil { - log.Println(err) - } + rows := database.query(getTasksql) defer rows.Close() for rows.Next() { err := rows.Scan(&TaskID, &TaskTitle, &TaskContent, &TaskCreated, &TaskPriority) @@ -73,10 +100,7 @@ func GetTaskByID(id int) types.Context { getTasksql := "select id, title, content, priority from task where id=?" - rows, err := database.Query(getTasksql, id) - if err != nil { - log.Println(err) - } + rows := database.query(getTasksql, id) defer rows.Close() if rows.Next() { err := rows.Scan(&task.Id, &task.Title, &task.Content, &task.Priority) @@ -92,153 +116,56 @@ func GetTaskByID(id int) types.Context { //TrashTask is used to delete the task func TrashTask(id int) error { - trashSQL, err := database.Prepare("update task set is_deleted='Y',last_modified_at=datetime() where id=?") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(trashSQL).Exec(id) - if err != nil { - log.Println("doing rollback") - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("update task set is_deleted='Y',last_modified_at=datetime() where id=?", id) return err } //CompleteTask is used to mark tasks as complete func CompleteTask(id int) error { - stmt, err := database.Prepare("update task set is_deleted='Y', finish_date=datetime(),last_modified_at=datetime() where id=?") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(stmt).Exec(id) - if err != nil { - log.Println(err) - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("update task set is_deleted='Y', finish_date=datetime(),last_modified_at=datetime() where id=?", id) return err } //DeleteAll is used to empty the trash func DeleteAll() error { - stmt, err := database.Prepare("delete from task where is_deleted='Y'") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(stmt).Exec() - if err != nil { - log.Println("doing rollback") - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("delete from task where is_deleted='Y'") return err } //RestoreTask is used to restore tasks from the Trash func RestoreTask(id int) error { - restoreSQL, err := database.Prepare("update task set is_deleted='N',last_modified_at=datetime() where id=?") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(restoreSQL).Exec(id) - if err != nil { - log.Println("doing rollback") - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("update task set is_deleted='N',last_modified_at=datetime() where id=?", id) return err } -//RestoreTask is used to restore tasks from the Trash +//RestoreTaskFromComplete is used to restore tasks from the Trash func RestoreTaskFromComplete(id int) error { - restoreSQL, err := database.Prepare("update task set finish_date=null,last_modified_at=datetime() where id=?") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(restoreSQL).Exec(id) - if err != nil { - log.Println("doing rollback") - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("update task set finish_date=null,last_modified_at=datetime() where id=?", id) return err } //DeleteTask is used to delete the task from the database func DeleteTask(id int) error { - deleteSQL, err := database.Prepare("delete from task where id = ?") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(deleteSQL).Exec(id) - if err != nil { - log.Println(err) - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("delete from task where id = ?", id) return err } //AddTask is used to add the task in the database func AddTask(title, content string, taskPriority int) error { - restoreSQL, err := database.Prepare("insert into task(title, content, priority, created_date, last_modified_at) values(?,?,?,datetime(), datetime())") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() - _, err = tx.Stmt(restoreSQL).Exec(title, content, taskPriority) - if err != nil { - log.Println(err) - tx.Rollback() - } else { - tx.Commit() - } + err := taskQuery("insert into task(title, content, priority, created_date, last_modified_at) values(?,?,?,datetime(), datetime())", title, content, taskPriority) return err } //UpdateTask is used to update the tasks in the database func UpdateTask(id int, title string, content string) error { - SQL, err := database.Prepare("update task set title=?, content=? where id=?") - if err != nil { - log.Println(err) - } - tx, err := database.Begin() + err := taskQuery("update task set title=?, content=? where id=?", title, content) + return err +} - if err != nil { - log.Println(err) - } - _, err = tx.Stmt(SQL).Exec(title, content, id) +func taskQuery(sql string, args ...interface{}) error { + SQL := database.prepare("update task set title=?, content=? where id=?") + tx := database.begin() + _, err = tx.Stmt(SQL).Exec(args...) if err != nil { log.Println(err) tx.Rollback() @@ -258,10 +185,8 @@ func SearchTask(query string) types.Context { var TaskCreated time.Time var context types.Context - rows, err := database.Query(stmt, query, query) - if err != nil { - log.Println(err) - } + rows := database.query(stmt, query, query) + for rows.Next() { err := rows.Scan(&TaskID, &TaskTitle, &TaskContent, &TaskCreated) if err != nil {