diff --git a/backend/internal/database/db.go b/backend/internal/database/db.go index ef365cd..320327a 100644 --- a/backend/internal/database/db.go +++ b/backend/internal/database/db.go @@ -2,7 +2,6 @@ package database import ( "embed" - "os" "path/filepath" "ttime/internal/types" @@ -19,7 +18,7 @@ type Database interface { PromoteToAdmin(username string) error GetUserId(username string) (int, error) AddProject(name string, description string, username string) error - Migrate(dirname string) error + Migrate() error GetProjectId(projectname string) (int, error) AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error AddUserToProject(username string, projectname string, role string) error @@ -259,13 +258,18 @@ func (d *Db) GetAllUsersApplication() ([]string, error) { // Reads a directory of migration files and applies them to the database. // This will eventually be used on an embedded directory -func (d *Db) Migrate(dirname string) error { +func (d *Db) Migrate() error { // Read the embedded scripts directory files, err := scripts.ReadDir("migrations") if err != nil { return err } + if len(files) == 0 { + println("No migration files found") + return nil + } + tr := d.MustBegin() // Iterate over each SQL file and execute it @@ -275,8 +279,7 @@ func (d *Db) Migrate(dirname string) error { } // This is perhaps not the most elegant way to do this - sqlFile := filepath.Join("migrations", file.Name()) - sqlBytes, err := os.ReadFile(sqlFile) + sqlBytes, err := scripts.ReadFile("migrations/" + file.Name()) if err != nil { return err } diff --git a/backend/internal/database/db_test.go b/backend/internal/database/db_test.go index 5438d66..9124c45 100644 --- a/backend/internal/database/db_test.go +++ b/backend/internal/database/db_test.go @@ -8,7 +8,7 @@ import ( func setupState() (Database, error) { db := DbConnect(":memory:") - err := db.Migrate("../../migrations") + err := db.Migrate() if err != nil { return nil, err } diff --git a/backend/main.go b/backend/main.go index 9ba2556..7f0f81e 100644 --- a/backend/main.go +++ b/backend/main.go @@ -43,6 +43,11 @@ func main() { // Connect to the database db := database.DbConnect(conf.DbPath) + // Migrate the database + if err = db.Migrate(); err != nil { + fmt.Println("Error migrating database: ", err) + } + // Get our global state gs := handlers.NewGlobalState(db) // Create the server