Major db/API refactor, transaction as default

This commit is contained in:
Imbus 2024-04-14 09:21:34 +02:00
parent fe9d5f74bb
commit c7e036ec04
4 changed files with 40 additions and 43 deletions

View file

@ -4,7 +4,6 @@ import (
"embed"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"ttime/internal/types"
@ -24,8 +23,6 @@ type Database interface {
GetUserId(username string) (int, error)
AddProject(name string, description string, username string) error
DeleteProject(name string, username string) error
Migrate() error
MigrateSampleData() error
GetProjectId(projectname string) (int, error)
AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error
AddUserToProject(username string, projectname string, role string) error
@ -55,7 +52,7 @@ type Database interface {
// This struct is a wrapper type that holds the database connection
// Internally DB holds a connection pool, so it's safe for concurrent use
type Db struct {
*sqlx.DB
*sqlx.Tx
}
type UserProjectMember struct {
@ -109,7 +106,7 @@ const reportStatistics = `SELECT SUM(development_time) AS total_development_time
GROUP BY user_id, project_id`
// DbConnect connects to the database
func DbConnect(dbpath string) Database {
func DbConnect(dbpath string) sqlx.DB {
// Open the database
db, err := sqlx.Connect("sqlite", dbpath)
if err != nil {
@ -122,7 +119,7 @@ func DbConnect(dbpath string) Database {
panic(err)
}
return &Db{db}
return *db
}
func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) {
@ -244,25 +241,15 @@ func (d *Db) GetProjectId(projectname string) (int, error) {
// Creates a new project in the database, associated with a user
func (d *Db) AddProject(name string, description string, username string) error {
tx := d.MustBegin()
// Insert the project into the database
_, err := tx.Exec(projectInsert, name, description, username)
_, err := d.Exec(projectInsert, name, description, username)
if err != nil {
if err := tx.Rollback(); err != nil {
return err
}
return err
}
// Add creator to project as project manager
_, err = tx.Exec(addUserToProject, username, name, "project_manager")
_, err = d.Exec(addUserToProject, username, name, "project_manager")
if err != nil {
if err := tx.Rollback(); err != nil {
return err
}
return err
}
if err := tx.Commit(); err != nil {
return err
}
@ -270,16 +257,7 @@ func (d *Db) AddProject(name string, description string, username string) error
}
func (d *Db) DeleteProject(projectID string, username string) error {
tx := d.MustBegin()
_, err := tx.Exec(deleteProject, projectID, username)
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("error rolling back transaction: %v, delete error: %v", rollbackErr, err)
}
panic(err)
}
_, err := d.Exec(deleteProject, projectID, username)
return err
}
@ -503,7 +481,7 @@ func (d *Db) IsSiteAdmin(username string) (bool, error) {
// Reads a directory of migration files and applies them to the database.
// This will eventually be used on an embedded directory
func (d *Db) Migrate() error {
func Migrate(db sqlx.DB) error {
// Read the embedded scripts directory
files, err := scripts.ReadDir("migrations")
if err != nil {
@ -515,7 +493,7 @@ func (d *Db) Migrate() error {
return nil
}
tr := d.MustBegin()
tr := db.MustBegin()
// Iterate over each SQL file and execute it
for _, file := range files {
@ -601,7 +579,7 @@ func (d *Db) UpdateWeeklyReport(projectName string, userName string, week int, d
}
// MigrateSampleData applies sample data to the database.
func (d *Db) MigrateSampleData() error {
func MigrateSampleData(db sqlx.DB) error {
// Insert sample data
files, err := sampleData.ReadDir("sample_data")
if err != nil {
@ -611,7 +589,7 @@ func (d *Db) MigrateSampleData() error {
if len(files) == 0 {
println("No sample data files found")
}
tr := d.MustBegin()
tr := db.MustBegin()
// Iterate over each SQL file and execute it
for _, file := range files {
@ -648,7 +626,7 @@ func (d *Db) GetProjectTimes(projectName string) (map[string]int, error) {
WHERE projects.name = ?
`
rows, err := d.DB.Query(query, projectName)
rows, err := d.Query(query, projectName)
if err != nil {
return nil, err
}

View file

@ -9,11 +9,13 @@ import (
// setupState initializes a database instance with necessary setup for testing
func setupState() (Database, error) {
db := DbConnect(":memory:")
err := db.Migrate()
err := Migrate(db)
if err != nil {
return nil, err
}
return db, nil
db_iface := Db{db.MustBegin()}
return &db_iface, nil
}
// This is a more advanced setup that includes more data in the database.
@ -1078,7 +1080,7 @@ func TestDeleteReport(t *testing.T) {
}
// Remove report
err = db.DeleteReport(report.ReportId,)
err = db.DeleteReport(report.ReportId)
if err != nil {
t.Error("RemoveReport failed:", err)
}

View file

@ -1,11 +1,28 @@
package database
import "github.com/gofiber/fiber/v2"
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/jmoiron/sqlx"
)
// Simple middleware that provides a shared database pool as a local key "db"
func DbMiddleware(db *Database) func(c *fiber.Ctx) error {
// Simple middleware that provides a transaction as a local key "db"
func DbMiddleware(db *sqlx.DB) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
c.Locals("db", db)
tx := db.MustBegin()
defer func() {
if err := tx.Commit(); err != nil {
if err = tx.Rollback(); err != nil {
log.Error("Failed to rollback transaction: ", err)
}
return
}
}()
var db_iface Database = &Db{tx}
c.Locals("db", &db_iface)
return c.Next()
}
}

View file

@ -59,13 +59,13 @@ func main() {
db := database.DbConnect(conf.DbPath)
// Migrate the database
if err = db.Migrate(); err != nil {
if err = database.Migrate(db); err != nil {
fmt.Println("Error migrating database: ", err)
os.Exit(1)
}
// Migrate sample data, should not be used in production
if err = db.MigrateSampleData(); err != nil {
if err = database.MigrateSampleData(db); err != nil {
fmt.Println("Error migrating sample data: ", err)
os.Exit(1)
}