diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 699ebf5..bb71e31 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "ttime/internal/config" "ttime/internal/database" _ "github.com/mattn/go-sqlite3" @@ -40,7 +41,15 @@ func handler(w http.ResponseWriter, r *http.Request) { } func main() { - database.DbConnect() + conf, err := config.ReadConfigFromFile("config.toml") + if err != nil { + conf = config.NewConfig() + conf.WriteConfigToFile("config.toml") + } + + println(conf) + + database.DbConnect("db.sqlite3") b := &ButtonState{PressCount: 0} // Mounting the handlers @@ -54,7 +63,7 @@ func main() { println("Visit http://localhost:8080/hello to see the hello handler in action") println("Visit http://localhost:8080/button to see the button handler in action") println("Press Ctrl+C to stop the server") - err := http.ListenAndServe(":8080", nil) + err = http.ListenAndServe(":8080", nil) if err != nil { panic(err) } diff --git a/backend/internal/database/db.go b/backend/internal/database/db.go index e0006d1..0334cc4 100644 --- a/backend/internal/database/db.go +++ b/backend/internal/database/db.go @@ -1,7 +1,9 @@ package database import ( + "log" "os" + "path/filepath" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" @@ -17,22 +19,7 @@ const userInsert = "INSERT INTO users (username, password) VALUES (?, ?)" const projectInsert = "INSERT INTO projects (name, description, user_id) SELECT ?, ?, id FROM users WHERE username = ?" // DbConnect connects to the database -func DbConnect() *Db { - // Check for the environment variable - dbpath := os.Getenv("SQLITE_DB_PATH") - - // Default to something reasonable - if dbpath == "" { - // This should obviously not be like this - dbpath = "../../db.sqlite3" // This is disaster waiting to happen - // WARNING - - // If the file doesn't exist, panic - if _, err := os.Stat(dbpath); os.IsNotExist(err) { - panic("Database file does not exist: " + dbpath) - } - } - +func DbConnect(dbpath string) *Db { // Open the database db, err := sqlx.Connect("sqlite3", dbpath) if err != nil { @@ -71,3 +58,37 @@ func (d *Db) AddProject(name string, description string, username string) error _, err := d.Exec(projectInsert, name, description, username) return err } + +// Reads a directory of migration files and applies them to the database. +// This will eventually be used on an embedded directory +func (d *Db) Migrate(dirname string) error { + files, err := os.ReadDir(dirname) + if err != nil { + return err + } + + tr := d.MustBegin() + + // Iterate over each SQL file and execute it + for _, file := range files { + if file.IsDir() || filepath.Ext(file.Name()) != ".sql" { + continue + } + + sqlFile := filepath.Join(dirname, file.Name()) + sqlBytes, err := os.ReadFile(sqlFile) + if err != nil { + return err + } + + sqlQuery := string(sqlBytes) + _, err = tr.Exec(sqlQuery) + if err != nil { + return err + } + log.Println("Executed SQL file:", file.Name()) + } + + tr.Commit() + return nil +} diff --git a/backend/internal/database/db_test.go b/backend/internal/database/db_test.go index ab0355a..62e47db 100644 --- a/backend/internal/database/db_test.go +++ b/backend/internal/database/db_test.go @@ -5,27 +5,42 @@ import ( ) // Tests are not guaranteed to be sequential -// Writing tests like this will bite you, eventually + +func setupState() (*Db, error) { + db := DbConnect(":memory:") + err := db.Migrate("../../migrations") + if err != nil { + return nil, err + } + return db, nil +} func TestDbConnect(t *testing.T) { - db := DbConnect() + db := DbConnect(":memory:") _ = db } func TestDbAddUser(t *testing.T) { - db := DbConnect() - err := db.AddUser("test", "password") + db, err := setupState() + if err != nil { + t.Error("setupState failed:", err) + } + err = db.AddUser("test", "password") if err != nil { t.Error("AddUser failed:", err) } } func TestDbGetUserId(t *testing.T) { - db := DbConnect() + db, err := setupState() + if err != nil { + t.Error("setupState failed:", err) + } + db.AddUser("test", "password") var id int - id, err := db.GetUserId("test") + id, err = db.GetUserId("test") if err != nil { t.Error("GetUserId failed:", err) } @@ -35,16 +50,24 @@ func TestDbGetUserId(t *testing.T) { } func TestDbAddProject(t *testing.T) { - db := DbConnect() - err := db.AddProject("test", "description", "test") + db, err := setupState() + if err != nil { + t.Error("setupState failed:", err) + } + + err = db.AddProject("test", "description", "test") if err != nil { t.Error("AddProject failed:", err) } } func TestDbRemoveUser(t *testing.T) { - db := DbConnect() - err := db.RemoveUser("test") + db, err := setupState() + if err != nil { + t.Error("setupState failed:", err) + } + + err = db.RemoveUser("test") if err != nil { t.Error("RemoveUser failed:", err) }