Migrations fix in go

This commit is contained in:
Imbus 2024-03-17 14:38:20 +01:00
parent 887f31dde0
commit 7c21677310
3 changed files with 14 additions and 6 deletions

View file

@ -2,7 +2,6 @@ package database
import ( import (
"embed" "embed"
"os"
"path/filepath" "path/filepath"
"ttime/internal/types" "ttime/internal/types"
@ -19,7 +18,7 @@ type Database interface {
PromoteToAdmin(username string) error PromoteToAdmin(username string) error
GetUserId(username string) (int, error) GetUserId(username string) (int, error)
AddProject(name string, description string, username string) error AddProject(name string, description string, username string) error
Migrate(dirname string) error Migrate() error
GetProjectId(projectname string) (int, error) GetProjectId(projectname string) (int, error)
AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error
AddUserToProject(username string, projectname string, role string) error AddUserToProject(username string, projectname string, role string) error
@ -259,13 +258,18 @@ func (d *Db) GetAllUsersApplication() ([]string, error) {
// Reads a directory of migration files and applies them to the database. // Reads a directory of migration files and applies them to the database.
// This will eventually be used on an embedded directory // This will eventually be used on an embedded directory
func (d *Db) Migrate(dirname string) error { func (d *Db) Migrate() error {
// Read the embedded scripts directory // Read the embedded scripts directory
files, err := scripts.ReadDir("migrations") files, err := scripts.ReadDir("migrations")
if err != nil { if err != nil {
return err return err
} }
if len(files) == 0 {
println("No migration files found")
return nil
}
tr := d.MustBegin() tr := d.MustBegin()
// Iterate over each SQL file and execute it // Iterate over each SQL file and execute it
@ -275,8 +279,7 @@ func (d *Db) Migrate(dirname string) error {
} }
// This is perhaps not the most elegant way to do this // This is perhaps not the most elegant way to do this
sqlFile := filepath.Join("migrations", file.Name()) sqlBytes, err := scripts.ReadFile("migrations/" + file.Name())
sqlBytes, err := os.ReadFile(sqlFile)
if err != nil { if err != nil {
return err return err
} }

View file

@ -8,7 +8,7 @@ import (
func setupState() (Database, error) { func setupState() (Database, error) {
db := DbConnect(":memory:") db := DbConnect(":memory:")
err := db.Migrate("../../migrations") err := db.Migrate()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -43,6 +43,11 @@ func main() {
// Connect to the database // Connect to the database
db := database.DbConnect(conf.DbPath) db := database.DbConnect(conf.DbPath)
// Migrate the database
if err = db.Migrate(); err != nil {
fmt.Println("Error migrating database: ", err)
}
// Get our global state // Get our global state
gs := handlers.NewGlobalState(db) gs := handlers.NewGlobalState(db)
// Create the server // Create the server