Major db/API refactor, transaction as default

This commit is contained in:
Imbus 2024-04-14 09:21:34 +02:00
parent fe9d5f74bb
commit c7e036ec04
4 changed files with 40 additions and 43 deletions

View file

@ -4,7 +4,6 @@ import (
"embed" "embed"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"path/filepath" "path/filepath"
"ttime/internal/types" "ttime/internal/types"
@ -24,8 +23,6 @@ type Database interface {
GetUserId(username string) (int, error) GetUserId(username string) (int, error)
AddProject(name string, description string, username string) error AddProject(name string, description string, username string) error
DeleteProject(name string, username string) error DeleteProject(name string, username string) error
Migrate() error
MigrateSampleData() error
GetProjectId(projectname string) (int, error) GetProjectId(projectname string) (int, error)
AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error
AddUserToProject(username string, projectname string, role string) error AddUserToProject(username string, projectname string, role string) error
@ -55,7 +52,7 @@ type Database interface {
// This struct is a wrapper type that holds the database connection // This struct is a wrapper type that holds the database connection
// Internally DB holds a connection pool, so it's safe for concurrent use // Internally DB holds a connection pool, so it's safe for concurrent use
type Db struct { type Db struct {
*sqlx.DB *sqlx.Tx
} }
type UserProjectMember struct { type UserProjectMember struct {
@ -109,7 +106,7 @@ const reportStatistics = `SELECT SUM(development_time) AS total_development_time
GROUP BY user_id, project_id` GROUP BY user_id, project_id`
// DbConnect connects to the database // DbConnect connects to the database
func DbConnect(dbpath string) Database { func DbConnect(dbpath string) sqlx.DB {
// Open the database // Open the database
db, err := sqlx.Connect("sqlite", dbpath) db, err := sqlx.Connect("sqlite", dbpath)
if err != nil { if err != nil {
@ -122,7 +119,7 @@ func DbConnect(dbpath string) Database {
panic(err) panic(err)
} }
return &Db{db} return *db
} }
func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) { func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) {
@ -244,25 +241,15 @@ func (d *Db) GetProjectId(projectname string) (int, error) {
// Creates a new project in the database, associated with a user // Creates a new project in the database, associated with a user
func (d *Db) AddProject(name string, description string, username string) error { func (d *Db) AddProject(name string, description string, username string) error {
tx := d.MustBegin()
// Insert the project into the database // Insert the project into the database
_, err := tx.Exec(projectInsert, name, description, username) _, err := d.Exec(projectInsert, name, description, username)
if err != nil { if err != nil {
if err := tx.Rollback(); err != nil {
return err
}
return err return err
} }
// Add creator to project as project manager // Add creator to project as project manager
_, err = tx.Exec(addUserToProject, username, name, "project_manager") _, err = d.Exec(addUserToProject, username, name, "project_manager")
if err != nil { if err != nil {
if err := tx.Rollback(); err != nil {
return err
}
return err
}
if err := tx.Commit(); err != nil {
return err return err
} }
@ -270,16 +257,7 @@ func (d *Db) AddProject(name string, description string, username string) error
} }
func (d *Db) DeleteProject(projectID string, username string) error { func (d *Db) DeleteProject(projectID string, username string) error {
tx := d.MustBegin() _, err := d.Exec(deleteProject, projectID, username)
_, err := tx.Exec(deleteProject, projectID, username)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("error rolling back transaction: %v, delete error: %v", rollbackErr, err)
}
panic(err)
}
return err return err
} }
@ -503,7 +481,7 @@ func (d *Db) IsSiteAdmin(username string) (bool, error) {
// Reads a directory of migration files and applies them to the database. // Reads a directory of migration files and applies them to the database.
// This will eventually be used on an embedded directory // This will eventually be used on an embedded directory
func (d *Db) Migrate() error { func Migrate(db sqlx.DB) error {
// Read the embedded scripts directory // Read the embedded scripts directory
files, err := scripts.ReadDir("migrations") files, err := scripts.ReadDir("migrations")
if err != nil { if err != nil {
@ -515,7 +493,7 @@ func (d *Db) Migrate() error {
return nil return nil
} }
tr := d.MustBegin() tr := db.MustBegin()
// Iterate over each SQL file and execute it // Iterate over each SQL file and execute it
for _, file := range files { for _, file := range files {
@ -601,7 +579,7 @@ func (d *Db) UpdateWeeklyReport(projectName string, userName string, week int, d
} }
// MigrateSampleData applies sample data to the database. // MigrateSampleData applies sample data to the database.
func (d *Db) MigrateSampleData() error { func MigrateSampleData(db sqlx.DB) error {
// Insert sample data // Insert sample data
files, err := sampleData.ReadDir("sample_data") files, err := sampleData.ReadDir("sample_data")
if err != nil { if err != nil {
@ -611,7 +589,7 @@ func (d *Db) MigrateSampleData() error {
if len(files) == 0 { if len(files) == 0 {
println("No sample data files found") println("No sample data files found")
} }
tr := d.MustBegin() tr := db.MustBegin()
// Iterate over each SQL file and execute it // Iterate over each SQL file and execute it
for _, file := range files { for _, file := range files {
@ -648,7 +626,7 @@ func (d *Db) GetProjectTimes(projectName string) (map[string]int, error) {
WHERE projects.name = ? WHERE projects.name = ?
` `
rows, err := d.DB.Query(query, projectName) rows, err := d.Query(query, projectName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -9,11 +9,13 @@ import (
// setupState initializes a database instance with necessary setup for testing // setupState initializes a database instance with necessary setup for testing
func setupState() (Database, error) { func setupState() (Database, error) {
db := DbConnect(":memory:") db := DbConnect(":memory:")
err := db.Migrate() err := Migrate(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return db, nil
db_iface := Db{db.MustBegin()}
return &db_iface, nil
} }
// This is a more advanced setup that includes more data in the database. // This is a more advanced setup that includes more data in the database.
@ -1078,7 +1080,7 @@ func TestDeleteReport(t *testing.T) {
} }
// Remove report // Remove report
err = db.DeleteReport(report.ReportId,) err = db.DeleteReport(report.ReportId)
if err != nil { if err != nil {
t.Error("RemoveReport failed:", err) t.Error("RemoveReport failed:", err)
} }
@ -1088,5 +1090,5 @@ func TestDeleteReport(t *testing.T) {
if err == nil { if err == nil {
t.Error("RemoveReport failed: report not removed") t.Error("RemoveReport failed: report not removed")
} }
} }

View file

@ -1,11 +1,28 @@
package database package database
import "github.com/gofiber/fiber/v2" import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/jmoiron/sqlx"
)
// Simple middleware that provides a shared database pool as a local key "db" // Simple middleware that provides a transaction as a local key "db"
func DbMiddleware(db *Database) func(c *fiber.Ctx) error { func DbMiddleware(db *sqlx.DB) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
c.Locals("db", db) tx := db.MustBegin()
defer func() {
if err := tx.Commit(); err != nil {
if err = tx.Rollback(); err != nil {
log.Error("Failed to rollback transaction: ", err)
}
return
}
}()
var db_iface Database = &Db{tx}
c.Locals("db", &db_iface)
return c.Next() return c.Next()
} }
} }

View file

@ -59,13 +59,13 @@ func main() {
db := database.DbConnect(conf.DbPath) db := database.DbConnect(conf.DbPath)
// Migrate the database // Migrate the database
if err = db.Migrate(); err != nil { if err = database.Migrate(db); err != nil {
fmt.Println("Error migrating database: ", err) fmt.Println("Error migrating database: ", err)
os.Exit(1) os.Exit(1)
} }
// Migrate sample data, should not be used in production // Migrate sample data, should not be used in production
if err = db.MigrateSampleData(); err != nil { if err = database.MigrateSampleData(db); err != nil {
fmt.Println("Error migrating sample data: ", err) fmt.Println("Error migrating sample data: ", err)
os.Exit(1) os.Exit(1)
} }