Major db/API refactor, transaction as default
This commit is contained in:
parent
fe9d5f74bb
commit
c7e036ec04
4 changed files with 40 additions and 43 deletions
|
@ -4,7 +4,6 @@ import (
|
|||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"ttime/internal/types"
|
||||
|
||||
|
@ -24,8 +23,6 @@ type Database interface {
|
|||
GetUserId(username string) (int, error)
|
||||
AddProject(name string, description string, username string) error
|
||||
DeleteProject(name string, username string) error
|
||||
Migrate() error
|
||||
MigrateSampleData() error
|
||||
GetProjectId(projectname string) (int, error)
|
||||
AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error
|
||||
AddUserToProject(username string, projectname string, role string) error
|
||||
|
@ -55,7 +52,7 @@ type Database interface {
|
|||
// This struct is a wrapper type that holds the database connection
|
||||
// Internally DB holds a connection pool, so it's safe for concurrent use
|
||||
type Db struct {
|
||||
*sqlx.DB
|
||||
*sqlx.Tx
|
||||
}
|
||||
|
||||
type UserProjectMember struct {
|
||||
|
@ -109,7 +106,7 @@ const reportStatistics = `SELECT SUM(development_time) AS total_development_time
|
|||
GROUP BY user_id, project_id`
|
||||
|
||||
// DbConnect connects to the database
|
||||
func DbConnect(dbpath string) Database {
|
||||
func DbConnect(dbpath string) sqlx.DB {
|
||||
// Open the database
|
||||
db, err := sqlx.Connect("sqlite", dbpath)
|
||||
if err != nil {
|
||||
|
@ -122,7 +119,7 @@ func DbConnect(dbpath string) Database {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
return &Db{db}
|
||||
return *db
|
||||
}
|
||||
|
||||
func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) {
|
||||
|
@ -244,25 +241,15 @@ func (d *Db) GetProjectId(projectname string) (int, error) {
|
|||
|
||||
// Creates a new project in the database, associated with a user
|
||||
func (d *Db) AddProject(name string, description string, username string) error {
|
||||
tx := d.MustBegin()
|
||||
// Insert the project into the database
|
||||
_, err := tx.Exec(projectInsert, name, description, username)
|
||||
_, err := d.Exec(projectInsert, name, description, username)
|
||||
if err != nil {
|
||||
if err := tx.Rollback(); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Add creator to project as project manager
|
||||
_, err = tx.Exec(addUserToProject, username, name, "project_manager")
|
||||
_, err = d.Exec(addUserToProject, username, name, "project_manager")
|
||||
if err != nil {
|
||||
if err := tx.Rollback(); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -270,16 +257,7 @@ func (d *Db) AddProject(name string, description string, username string) error
|
|||
}
|
||||
|
||||
func (d *Db) DeleteProject(projectID string, username string) error {
|
||||
tx := d.MustBegin()
|
||||
|
||||
_, err := tx.Exec(deleteProject, projectID, username)
|
||||
|
||||
if err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("error rolling back transaction: %v, delete error: %v", rollbackErr, err)
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
_, err := d.Exec(deleteProject, projectID, username)
|
||||
|
||||
return err
|
||||
}
|
||||
|
@ -503,7 +481,7 @@ func (d *Db) IsSiteAdmin(username string) (bool, error) {
|
|||
|
||||
// Reads a directory of migration files and applies them to the database.
|
||||
// This will eventually be used on an embedded directory
|
||||
func (d *Db) Migrate() error {
|
||||
func Migrate(db sqlx.DB) error {
|
||||
// Read the embedded scripts directory
|
||||
files, err := scripts.ReadDir("migrations")
|
||||
if err != nil {
|
||||
|
@ -515,7 +493,7 @@ func (d *Db) Migrate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
tr := d.MustBegin()
|
||||
tr := db.MustBegin()
|
||||
|
||||
// Iterate over each SQL file and execute it
|
||||
for _, file := range files {
|
||||
|
@ -601,7 +579,7 @@ func (d *Db) UpdateWeeklyReport(projectName string, userName string, week int, d
|
|||
}
|
||||
|
||||
// MigrateSampleData applies sample data to the database.
|
||||
func (d *Db) MigrateSampleData() error {
|
||||
func MigrateSampleData(db sqlx.DB) error {
|
||||
// Insert sample data
|
||||
files, err := sampleData.ReadDir("sample_data")
|
||||
if err != nil {
|
||||
|
@ -611,7 +589,7 @@ func (d *Db) MigrateSampleData() error {
|
|||
if len(files) == 0 {
|
||||
println("No sample data files found")
|
||||
}
|
||||
tr := d.MustBegin()
|
||||
tr := db.MustBegin()
|
||||
|
||||
// Iterate over each SQL file and execute it
|
||||
for _, file := range files {
|
||||
|
@ -648,7 +626,7 @@ func (d *Db) GetProjectTimes(projectName string) (map[string]int, error) {
|
|||
WHERE projects.name = ?
|
||||
`
|
||||
|
||||
rows, err := d.DB.Query(query, projectName)
|
||||
rows, err := d.Query(query, projectName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -9,11 +9,13 @@ import (
|
|||
// setupState initializes a database instance with necessary setup for testing
|
||||
func setupState() (Database, error) {
|
||||
db := DbConnect(":memory:")
|
||||
err := db.Migrate()
|
||||
err := Migrate(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
|
||||
db_iface := Db{db.MustBegin()}
|
||||
return &db_iface, nil
|
||||
}
|
||||
|
||||
// This is a more advanced setup that includes more data in the database.
|
||||
|
@ -1078,7 +1080,7 @@ func TestDeleteReport(t *testing.T) {
|
|||
}
|
||||
|
||||
// Remove report
|
||||
err = db.DeleteReport(report.ReportId,)
|
||||
err = db.DeleteReport(report.ReportId)
|
||||
if err != nil {
|
||||
t.Error("RemoveReport failed:", err)
|
||||
}
|
||||
|
@ -1088,5 +1090,5 @@ func TestDeleteReport(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Error("RemoveReport failed: report not removed")
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -1,11 +1,28 @@
|
|||
package database
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/log"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// Simple middleware that provides a shared database pool as a local key "db"
|
||||
func DbMiddleware(db *Database) func(c *fiber.Ctx) error {
|
||||
// Simple middleware that provides a transaction as a local key "db"
|
||||
func DbMiddleware(db *sqlx.DB) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
c.Locals("db", db)
|
||||
tx := db.MustBegin()
|
||||
|
||||
defer func() {
|
||||
if err := tx.Commit(); err != nil {
|
||||
if err = tx.Rollback(); err != nil {
|
||||
log.Error("Failed to rollback transaction: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
var db_iface Database = &Db{tx}
|
||||
|
||||
c.Locals("db", &db_iface)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,13 +59,13 @@ func main() {
|
|||
db := database.DbConnect(conf.DbPath)
|
||||
|
||||
// Migrate the database
|
||||
if err = db.Migrate(); err != nil {
|
||||
if err = database.Migrate(db); err != nil {
|
||||
fmt.Println("Error migrating database: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Migrate sample data, should not be used in production
|
||||
if err = db.MigrateSampleData(); err != nil {
|
||||
if err = database.MigrateSampleData(db); err != nil {
|
||||
fmt.Println("Error migrating sample data: ", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue