Major db/API refactor, transaction as default
This commit is contained in:
parent
fe9d5f74bb
commit
c7e036ec04
4 changed files with 40 additions and 43 deletions
|
@ -4,7 +4,6 @@ import (
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"ttime/internal/types"
|
"ttime/internal/types"
|
||||||
|
|
||||||
|
@ -24,8 +23,6 @@ type Database interface {
|
||||||
GetUserId(username string) (int, error)
|
GetUserId(username string) (int, error)
|
||||||
AddProject(name string, description string, username string) error
|
AddProject(name string, description string, username string) error
|
||||||
DeleteProject(name string, username string) error
|
DeleteProject(name string, username string) error
|
||||||
Migrate() error
|
|
||||||
MigrateSampleData() error
|
|
||||||
GetProjectId(projectname string) (int, error)
|
GetProjectId(projectname string) (int, error)
|
||||||
AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error
|
AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error
|
||||||
AddUserToProject(username string, projectname string, role string) error
|
AddUserToProject(username string, projectname string, role string) error
|
||||||
|
@ -55,7 +52,7 @@ type Database interface {
|
||||||
// This struct is a wrapper type that holds the database connection
|
// This struct is a wrapper type that holds the database connection
|
||||||
// Internally DB holds a connection pool, so it's safe for concurrent use
|
// Internally DB holds a connection pool, so it's safe for concurrent use
|
||||||
type Db struct {
|
type Db struct {
|
||||||
*sqlx.DB
|
*sqlx.Tx
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserProjectMember struct {
|
type UserProjectMember struct {
|
||||||
|
@ -109,7 +106,7 @@ const reportStatistics = `SELECT SUM(development_time) AS total_development_time
|
||||||
GROUP BY user_id, project_id`
|
GROUP BY user_id, project_id`
|
||||||
|
|
||||||
// DbConnect connects to the database
|
// DbConnect connects to the database
|
||||||
func DbConnect(dbpath string) Database {
|
func DbConnect(dbpath string) sqlx.DB {
|
||||||
// Open the database
|
// Open the database
|
||||||
db, err := sqlx.Connect("sqlite", dbpath)
|
db, err := sqlx.Connect("sqlite", dbpath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -122,7 +119,7 @@ func DbConnect(dbpath string) Database {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Db{db}
|
return *db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) {
|
func (d *Db) ReportStatistics(username string, projectName string) (*types.Statistics, error) {
|
||||||
|
@ -244,25 +241,15 @@ func (d *Db) GetProjectId(projectname string) (int, error) {
|
||||||
|
|
||||||
// Creates a new project in the database, associated with a user
|
// Creates a new project in the database, associated with a user
|
||||||
func (d *Db) AddProject(name string, description string, username string) error {
|
func (d *Db) AddProject(name string, description string, username string) error {
|
||||||
tx := d.MustBegin()
|
|
||||||
// Insert the project into the database
|
// Insert the project into the database
|
||||||
_, err := tx.Exec(projectInsert, name, description, username)
|
_, err := d.Exec(projectInsert, name, description, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := tx.Rollback(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add creator to project as project manager
|
// Add creator to project as project manager
|
||||||
_, err = tx.Exec(addUserToProject, username, name, "project_manager")
|
_, err = d.Exec(addUserToProject, username, name, "project_manager")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := tx.Rollback(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,16 +257,7 @@ func (d *Db) AddProject(name string, description string, username string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Db) DeleteProject(projectID string, username string) error {
|
func (d *Db) DeleteProject(projectID string, username string) error {
|
||||||
tx := d.MustBegin()
|
_, err := d.Exec(deleteProject, projectID, username)
|
||||||
|
|
||||||
_, err := tx.Exec(deleteProject, projectID, username)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
||||||
return fmt.Errorf("error rolling back transaction: %v, delete error: %v", rollbackErr, err)
|
|
||||||
}
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -503,7 +481,7 @@ func (d *Db) IsSiteAdmin(username string) (bool, error) {
|
||||||
|
|
||||||
// Reads a directory of migration files and applies them to the database.
|
// Reads a directory of migration files and applies them to the database.
|
||||||
// This will eventually be used on an embedded directory
|
// This will eventually be used on an embedded directory
|
||||||
func (d *Db) Migrate() error {
|
func Migrate(db sqlx.DB) error {
|
||||||
// Read the embedded scripts directory
|
// Read the embedded scripts directory
|
||||||
files, err := scripts.ReadDir("migrations")
|
files, err := scripts.ReadDir("migrations")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -515,7 +493,7 @@ func (d *Db) Migrate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tr := d.MustBegin()
|
tr := db.MustBegin()
|
||||||
|
|
||||||
// Iterate over each SQL file and execute it
|
// Iterate over each SQL file and execute it
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
|
@ -601,7 +579,7 @@ func (d *Db) UpdateWeeklyReport(projectName string, userName string, week int, d
|
||||||
}
|
}
|
||||||
|
|
||||||
// MigrateSampleData applies sample data to the database.
|
// MigrateSampleData applies sample data to the database.
|
||||||
func (d *Db) MigrateSampleData() error {
|
func MigrateSampleData(db sqlx.DB) error {
|
||||||
// Insert sample data
|
// Insert sample data
|
||||||
files, err := sampleData.ReadDir("sample_data")
|
files, err := sampleData.ReadDir("sample_data")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -611,7 +589,7 @@ func (d *Db) MigrateSampleData() error {
|
||||||
if len(files) == 0 {
|
if len(files) == 0 {
|
||||||
println("No sample data files found")
|
println("No sample data files found")
|
||||||
}
|
}
|
||||||
tr := d.MustBegin()
|
tr := db.MustBegin()
|
||||||
|
|
||||||
// Iterate over each SQL file and execute it
|
// Iterate over each SQL file and execute it
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
|
@ -648,7 +626,7 @@ func (d *Db) GetProjectTimes(projectName string) (map[string]int, error) {
|
||||||
WHERE projects.name = ?
|
WHERE projects.name = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := d.DB.Query(query, projectName)
|
rows, err := d.Query(query, projectName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,11 +9,13 @@ import (
|
||||||
// setupState initializes a database instance with necessary setup for testing
|
// setupState initializes a database instance with necessary setup for testing
|
||||||
func setupState() (Database, error) {
|
func setupState() (Database, error) {
|
||||||
db := DbConnect(":memory:")
|
db := DbConnect(":memory:")
|
||||||
err := db.Migrate()
|
err := Migrate(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return db, nil
|
|
||||||
|
db_iface := Db{db.MustBegin()}
|
||||||
|
return &db_iface, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is a more advanced setup that includes more data in the database.
|
// This is a more advanced setup that includes more data in the database.
|
||||||
|
@ -1078,7 +1080,7 @@ func TestDeleteReport(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove report
|
// Remove report
|
||||||
err = db.DeleteReport(report.ReportId,)
|
err = db.DeleteReport(report.ReportId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("RemoveReport failed:", err)
|
t.Error("RemoveReport failed:", err)
|
||||||
}
|
}
|
||||||
|
@ -1088,5 +1090,5 @@ func TestDeleteReport(t *testing.T) {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("RemoveReport failed: report not removed")
|
t.Error("RemoveReport failed: report not removed")
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,28 @@
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import "github.com/gofiber/fiber/v2"
|
import (
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/fiber/v2/log"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
)
|
||||||
|
|
||||||
// Simple middleware that provides a shared database pool as a local key "db"
|
// Simple middleware that provides a transaction as a local key "db"
|
||||||
func DbMiddleware(db *Database) func(c *fiber.Ctx) error {
|
func DbMiddleware(db *sqlx.DB) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
c.Locals("db", db)
|
tx := db.MustBegin()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
if err = tx.Rollback(); err != nil {
|
||||||
|
log.Error("Failed to rollback transaction: ", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var db_iface Database = &Db{tx}
|
||||||
|
|
||||||
|
c.Locals("db", &db_iface)
|
||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,13 +59,13 @@ func main() {
|
||||||
db := database.DbConnect(conf.DbPath)
|
db := database.DbConnect(conf.DbPath)
|
||||||
|
|
||||||
// Migrate the database
|
// Migrate the database
|
||||||
if err = db.Migrate(); err != nil {
|
if err = database.Migrate(db); err != nil {
|
||||||
fmt.Println("Error migrating database: ", err)
|
fmt.Println("Error migrating database: ", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate sample data, should not be used in production
|
// Migrate sample data, should not be used in production
|
||||||
if err = db.MigrateSampleData(); err != nil {
|
if err = database.MigrateSampleData(db); err != nil {
|
||||||
fmt.Println("Error migrating sample data: ", err)
|
fmt.Println("Error migrating sample data: ", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue