From c7e036ec0497cd1093aae89afc316eb246fda200 Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Sun, 14 Apr 2024 09:21:34 +0200 Subject: [PATCH] Major db/API refactor, transaction as default --- backend/internal/database/db.go | 44 +++++++------------------ backend/internal/database/db_test.go | 10 +++--- backend/internal/database/middleware.go | 25 +++++++++++--- backend/main.go | 4 +-- 4 files changed, 40 insertions(+), 43 deletions(-) diff --git a/backend/internal/database/db.go b/backend/internal/database/db.go index e0f97c3..d444b85 100644 --- a/backend/internal/database/db.go +++ b/backend/internal/database/db.go @@ -4,7 +4,6 @@ import ( "embed" "encoding/json" "errors" - "fmt" "path/filepath" "ttime/internal/types" @@ -24,8 +23,6 @@ type Database interface { GetUserId(username string) (int, error) AddProject(name string, description string, username string) error DeleteProject(name string, username string) error - Migrate() error - MigrateSampleData() error GetProjectId(projectname string) (int, error) AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error AddUserToProject(username string, projectname string, role string) error @@ -55,7 +52,7 @@ type Database interface { // This struct is a wrapper type that holds the database connection // Internally DB holds a connection pool, so it's safe for concurrent use type Db struct { - *sqlx.DB + *sqlx.Tx } type UserProjectMember struct { @@ -109,7 +106,7 @@ const reportStatistics = `SELECT SUM(development_time) AS total_development_time GROUP BY user_id, project_id` // DbConnect connects to the database -func DbConnect(dbpath string) Database { +func DbConnect(dbpath string) sqlx.DB { // Open the database db, err := sqlx.Connect("sqlite", dbpath) if err != nil { @@ -122,7 +119,7 @@ func DbConnect(dbpath string) Database { panic(err) } - return &Db{db} + return *db } func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) { @@ -244,25 +241,15 @@ func (d *Db) GetProjectId(projectname string) (int, error) { // Creates a new project in the database, associated with a user func (d *Db) AddProject(name string, description string, username string) error { - tx := d.MustBegin() // Insert the project into the database - _, err := tx.Exec(projectInsert, name, description, username) + _, err := d.Exec(projectInsert, name, description, username) if err != nil { - if err := tx.Rollback(); err != nil { - return err - } return err } // Add creator to project as project manager - _, err = tx.Exec(addUserToProject, username, name, "project_manager") + _, err = d.Exec(addUserToProject, username, name, "project_manager") if err != nil { - if err := tx.Rollback(); err != nil { - return err - } - return err - } - if err := tx.Commit(); err != nil { return err } @@ -270,16 +257,7 @@ func (d *Db) AddProject(name string, description string, username string) error } func (d *Db) DeleteProject(projectID string, username string) error { - tx := d.MustBegin() - - _, err := tx.Exec(deleteProject, projectID, username) - - if err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("error rolling back transaction: %v, delete error: %v", rollbackErr, err) - } - panic(err) - } + _, err := d.Exec(deleteProject, projectID, username) return err } @@ -503,7 +481,7 @@ func (d *Db) IsSiteAdmin(username string) (bool, error) { // Reads a directory of migration files and applies them to the database. // This will eventually be used on an embedded directory -func (d *Db) Migrate() error { +func Migrate(db sqlx.DB) error { // Read the embedded scripts directory files, err := scripts.ReadDir("migrations") if err != nil { @@ -515,7 +493,7 @@ func (d *Db) Migrate() error { return nil } - tr := d.MustBegin() + tr := db.MustBegin() // Iterate over each SQL file and execute it for _, file := range files { @@ -601,7 +579,7 @@ func (d *Db) UpdateWeeklyReport(projectName string, userName string, week int, d } // MigrateSampleData applies sample data to the database. -func (d *Db) MigrateSampleData() error { +func MigrateSampleData(db sqlx.DB) error { // Insert sample data files, err := sampleData.ReadDir("sample_data") if err != nil { @@ -611,7 +589,7 @@ func (d *Db) MigrateSampleData() error { if len(files) == 0 { println("No sample data files found") } - tr := d.MustBegin() + tr := db.MustBegin() // Iterate over each SQL file and execute it for _, file := range files { @@ -648,7 +626,7 @@ func (d *Db) GetProjectTimes(projectName string) (map[string]int, error) { WHERE projects.name = ? ` - rows, err := d.DB.Query(query, projectName) + rows, err := d.Query(query, projectName) if err != nil { return nil, err } diff --git a/backend/internal/database/db_test.go b/backend/internal/database/db_test.go index b5a598c..f24175a 100644 --- a/backend/internal/database/db_test.go +++ b/backend/internal/database/db_test.go @@ -9,11 +9,13 @@ import ( // setupState initializes a database instance with necessary setup for testing func setupState() (Database, error) { db := DbConnect(":memory:") - err := db.Migrate() + err := Migrate(db) if err != nil { return nil, err } - return db, nil + + db_iface := Db{db.MustBegin()} + return &db_iface, nil } // This is a more advanced setup that includes more data in the database. @@ -1078,7 +1080,7 @@ func TestDeleteReport(t *testing.T) { } // Remove report - err = db.DeleteReport(report.ReportId,) + err = db.DeleteReport(report.ReportId) if err != nil { t.Error("RemoveReport failed:", err) } @@ -1088,5 +1090,5 @@ func TestDeleteReport(t *testing.T) { if err == nil { t.Error("RemoveReport failed: report not removed") } - + } diff --git a/backend/internal/database/middleware.go b/backend/internal/database/middleware.go index 69fa3a2..b73a42f 100644 --- a/backend/internal/database/middleware.go +++ b/backend/internal/database/middleware.go @@ -1,11 +1,28 @@ package database -import "github.com/gofiber/fiber/v2" +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/log" + "github.com/jmoiron/sqlx" +) -// Simple middleware that provides a shared database pool as a local key "db" -func DbMiddleware(db *Database) func(c *fiber.Ctx) error { +// Simple middleware that provides a transaction as a local key "db" +func DbMiddleware(db *sqlx.DB) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - c.Locals("db", db) + tx := db.MustBegin() + + defer func() { + if err := tx.Commit(); err != nil { + if err = tx.Rollback(); err != nil { + log.Error("Failed to rollback transaction: ", err) + } + return + } + }() + + var db_iface Database = &Db{tx} + + c.Locals("db", &db_iface) return c.Next() } } diff --git a/backend/main.go b/backend/main.go index accae7a..8a3466a 100644 --- a/backend/main.go +++ b/backend/main.go @@ -59,13 +59,13 @@ func main() { db := database.DbConnect(conf.DbPath) // Migrate the database - if err = db.Migrate(); err != nil { + if err = database.Migrate(db); err != nil { fmt.Println("Error migrating database: ", err) os.Exit(1) } // Migrate sample data, should not be used in production - if err = db.MigrateSampleData(); err != nil { + if err = database.MigrateSampleData(db); err != nil { fmt.Println("Error migrating sample data: ", err) os.Exit(1) }