Merge pull request #13 from Skarlso/master

Encapsulate database to its own type for standard error handling.
This commit is contained in:
Suraj Patil 2016-01-31 19:31:42 +05:30
commit 58d60f1615
2 changed files with 58 additions and 140 deletions

View File

@ -7,15 +7,8 @@ import (
// AddFile is used to add the md5 of a file name which is uploaded to our application // AddFile is used to add the md5 of a file name which is uploaded to our application
// this will enable us to randomize the URL without worrying about the file names // this will enable us to randomize the URL without worrying about the file names
func AddFile(fileName, token string) error { func AddFile(fileName, token string) error {
SQL, err := database.Prepare("insert into files values(?,?)") SQL := database.prepare("insert into files values(?,?)")
if err != nil { tx := database.begin()
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(SQL).Exec(fileName, token) _, err = tx.Stmt(SQL).Exec(fileName, token)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -30,7 +23,7 @@ func AddFile(fileName, token string) error {
func GetFileName(token string) (string, error) { func GetFileName(token string) (string, error) {
sql := "select name from files where autoName=?" sql := "select name from files where autoName=?"
var fileName string var fileName string
rows, err := database.Query(sql, fileName) rows := database.query(sql, fileName)
if rows.Next() { if rows.Next() {
err := rows.Scan(&fileName) err := rows.Scan(&fileName)
if err != nil { if err != nil {

View File

@ -2,19 +2,49 @@ package db
import ( import (
"database/sql" "database/sql"
_ "github.com/mattn/go-sqlite3" //we want to use sqlite natively
md "github.com/shurcooL/github_flavored_markdown"
"github.com/thewhitetulip/Tasks/types"
"log" "log"
"strings" "strings"
"time" "time"
_ "github.com/mattn/go-sqlite3" //we want to use sqlite natively
md "github.com/shurcooL/github_flavored_markdown"
"github.com/thewhitetulip/Tasks/types"
) )
var database *sql.DB var database Database
var err error var err error
//Database encapsulates database
type Database struct {
db *sql.DB
}
func (db Database) begin() (tx *sql.Tx) {
tx, err := db.db.Begin()
if err != nil {
log.Println(err)
}
return tx
}
func (db Database) prepare(q string) (stmt *sql.Stmt) {
stmt, err := db.db.Prepare(q)
if err != nil {
log.Println(err)
}
return stmt
}
func (db Database) query(q string, args ...interface{}) (rows *sql.Rows) {
rows, err := db.db.Query(q, args...)
if err != nil {
log.Println(err)
}
return rows
}
func init() { func init() {
database, err = sql.Open("sqlite3", "./tasks.db") database.db, err = sql.Open("sqlite3", "./tasks.db")
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
@ -22,7 +52,7 @@ func init() {
//Close function closes this database connection //Close function closes this database connection
func Close() { func Close() {
database.Close() database.db.Close()
} }
//GetTasks retrieves all the tasks depending on the //GetTasks retrieves all the tasks depending on the
@ -46,10 +76,7 @@ func GetTasks(status string) types.Context {
getTasksql = basicSQL + " where finish_date is not null order by priority desc, created_date asc" getTasksql = basicSQL + " where finish_date is not null order by priority desc, created_date asc"
} }
rows, err := database.Query(getTasksql) rows := database.query(getTasksql)
if err != nil {
log.Println(err)
}
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
err := rows.Scan(&TaskID, &TaskTitle, &TaskContent, &TaskCreated, &TaskPriority) err := rows.Scan(&TaskID, &TaskTitle, &TaskContent, &TaskCreated, &TaskPriority)
@ -73,10 +100,7 @@ func GetTaskByID(id int) types.Context {
getTasksql := "select id, title, content, priority from task where id=?" getTasksql := "select id, title, content, priority from task where id=?"
rows, err := database.Query(getTasksql, id) rows := database.query(getTasksql, id)
if err != nil {
log.Println(err)
}
defer rows.Close() defer rows.Close()
if rows.Next() { if rows.Next() {
err := rows.Scan(&task.Id, &task.Title, &task.Content, &task.Priority) err := rows.Scan(&task.Id, &task.Title, &task.Content, &task.Priority)
@ -92,153 +116,56 @@ func GetTaskByID(id int) types.Context {
//TrashTask is used to delete the task //TrashTask is used to delete the task
func TrashTask(id int) error { func TrashTask(id int) error {
trashSQL, err := database.Prepare("update task set is_deleted='Y',last_modified_at=datetime() where id=?") err := taskQuery("update task set is_deleted='Y',last_modified_at=datetime() where id=?", id)
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(trashSQL).Exec(id)
if err != nil {
log.Println("doing rollback")
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//CompleteTask is used to mark tasks as complete //CompleteTask is used to mark tasks as complete
func CompleteTask(id int) error { func CompleteTask(id int) error {
stmt, err := database.Prepare("update task set is_deleted='Y', finish_date=datetime(),last_modified_at=datetime() where id=?") err := taskQuery("update task set is_deleted='Y', finish_date=datetime(),last_modified_at=datetime() where id=?", id)
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(stmt).Exec(id)
if err != nil {
log.Println(err)
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//DeleteAll is used to empty the trash //DeleteAll is used to empty the trash
func DeleteAll() error { func DeleteAll() error {
stmt, err := database.Prepare("delete from task where is_deleted='Y'") err := taskQuery("delete from task where is_deleted='Y'")
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(stmt).Exec()
if err != nil {
log.Println("doing rollback")
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//RestoreTask is used to restore tasks from the Trash //RestoreTask is used to restore tasks from the Trash
func RestoreTask(id int) error { func RestoreTask(id int) error {
restoreSQL, err := database.Prepare("update task set is_deleted='N',last_modified_at=datetime() where id=?") err := taskQuery("update task set is_deleted='N',last_modified_at=datetime() where id=?", id)
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(restoreSQL).Exec(id)
if err != nil {
log.Println("doing rollback")
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//RestoreTask is used to restore tasks from the Trash //RestoreTaskFromComplete is used to restore tasks from the Trash
func RestoreTaskFromComplete(id int) error { func RestoreTaskFromComplete(id int) error {
restoreSQL, err := database.Prepare("update task set finish_date=null,last_modified_at=datetime() where id=?") err := taskQuery("update task set finish_date=null,last_modified_at=datetime() where id=?", id)
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(restoreSQL).Exec(id)
if err != nil {
log.Println("doing rollback")
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//DeleteTask is used to delete the task from the database //DeleteTask is used to delete the task from the database
func DeleteTask(id int) error { func DeleteTask(id int) error {
deleteSQL, err := database.Prepare("delete from task where id = ?") err := taskQuery("delete from task where id = ?", id)
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
if err != nil {
log.Println(err)
}
_, err = tx.Stmt(deleteSQL).Exec(id)
if err != nil {
log.Println(err)
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//AddTask is used to add the task in the database //AddTask is used to add the task in the database
func AddTask(title, content string, taskPriority int) error { func AddTask(title, content string, taskPriority int) error {
restoreSQL, err := database.Prepare("insert into task(title, content, priority, created_date, last_modified_at) values(?,?,?,datetime(), datetime())") err := taskQuery("insert into task(title, content, priority, created_date, last_modified_at) values(?,?,?,datetime(), datetime())", title, content, taskPriority)
if err != nil {
log.Println(err)
}
tx, err := database.Begin()
_, err = tx.Stmt(restoreSQL).Exec(title, content, taskPriority)
if err != nil {
log.Println(err)
tx.Rollback()
} else {
tx.Commit()
}
return err return err
} }
//UpdateTask is used to update the tasks in the database //UpdateTask is used to update the tasks in the database
func UpdateTask(id int, title string, content string) error { func UpdateTask(id int, title string, content string) error {
SQL, err := database.Prepare("update task set title=?, content=? where id=?") err := taskQuery("update task set title=?, content=? where id=?", title, content)
if err != nil { return err
log.Println(err) }
}
tx, err := database.Begin()
if err != nil { func taskQuery(sql string, args ...interface{}) error {
log.Println(err) SQL := database.prepare("update task set title=?, content=? where id=?")
} tx := database.begin()
_, err = tx.Stmt(SQL).Exec(title, content, id) _, err = tx.Stmt(SQL).Exec(args...)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
tx.Rollback() tx.Rollback()
@ -258,10 +185,8 @@ func SearchTask(query string) types.Context {
var TaskCreated time.Time var TaskCreated time.Time
var context types.Context var context types.Context
rows, err := database.Query(stmt, query, query) rows := database.query(stmt, query, query)
if err != nil {
log.Println(err)
}
for rows.Next() { for rows.Next() {
err := rows.Scan(&TaskID, &TaskTitle, &TaskContent, &TaskCreated) err := rows.Scan(&TaskID, &TaskTitle, &TaskContent, &TaskCreated)
if err != nil { if err != nil {