Config parsing and database api changes

This commit is contained in:
Imbus 2024-02-28 03:21:13 +01:00
parent 06632c16da
commit e1cd596c13
3 changed files with 81 additions and 28 deletions

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"ttime/internal/config"
"ttime/internal/database" "ttime/internal/database"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -40,7 +41,15 @@ func handler(w http.ResponseWriter, r *http.Request) {
} }
func main() { func main() {
database.DbConnect() conf, err := config.ReadConfigFromFile("config.toml")
if err != nil {
conf = config.NewConfig()
conf.WriteConfigToFile("config.toml")
}
println(conf)
database.DbConnect("db.sqlite3")
b := &ButtonState{PressCount: 0} b := &ButtonState{PressCount: 0}
// Mounting the handlers // Mounting the handlers
@ -54,7 +63,7 @@ func main() {
println("Visit http://localhost:8080/hello to see the hello handler in action") println("Visit http://localhost:8080/hello to see the hello handler in action")
println("Visit http://localhost:8080/button to see the button handler in action") println("Visit http://localhost:8080/button to see the button handler in action")
println("Press Ctrl+C to stop the server") println("Press Ctrl+C to stop the server")
err := http.ListenAndServe(":8080", nil) err = http.ListenAndServe(":8080", nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -1,7 +1,9 @@
package database package database
import ( import (
"log"
"os" "os"
"path/filepath"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -17,22 +19,7 @@ const userInsert = "INSERT INTO users (username, password) VALUES (?, ?)"
const projectInsert = "INSERT INTO projects (name, description, user_id) SELECT ?, ?, id FROM users WHERE username = ?" const projectInsert = "INSERT INTO projects (name, description, user_id) SELECT ?, ?, id FROM users WHERE username = ?"
// DbConnect connects to the database // DbConnect connects to the database
func DbConnect() *Db { func DbConnect(dbpath string) *Db {
// Check for the environment variable
dbpath := os.Getenv("SQLITE_DB_PATH")
// Default to something reasonable
if dbpath == "" {
// This should obviously not be like this
dbpath = "../../db.sqlite3" // This is disaster waiting to happen
// WARNING
// If the file doesn't exist, panic
if _, err := os.Stat(dbpath); os.IsNotExist(err) {
panic("Database file does not exist: " + dbpath)
}
}
// Open the database // Open the database
db, err := sqlx.Connect("sqlite3", dbpath) db, err := sqlx.Connect("sqlite3", dbpath)
if err != nil { if err != nil {
@ -71,3 +58,37 @@ func (d *Db) AddProject(name string, description string, username string) error
_, err := d.Exec(projectInsert, name, description, username) _, err := d.Exec(projectInsert, name, description, username)
return err return err
} }
// Reads a directory of migration files and applies them to the database.
// This will eventually be used on an embedded directory
func (d *Db) Migrate(dirname string) error {
files, err := os.ReadDir(dirname)
if err != nil {
return err
}
tr := d.MustBegin()
// Iterate over each SQL file and execute it
for _, file := range files {
if file.IsDir() || filepath.Ext(file.Name()) != ".sql" {
continue
}
sqlFile := filepath.Join(dirname, file.Name())
sqlBytes, err := os.ReadFile(sqlFile)
if err != nil {
return err
}
sqlQuery := string(sqlBytes)
_, err = tr.Exec(sqlQuery)
if err != nil {
return err
}
log.Println("Executed SQL file:", file.Name())
}
tr.Commit()
return nil
}

View file

@ -5,27 +5,42 @@ import (
) )
// Tests are not guaranteed to be sequential // Tests are not guaranteed to be sequential
// Writing tests like this will bite you, eventually
func setupState() (*Db, error) {
db := DbConnect(":memory:")
err := db.Migrate("../../migrations")
if err != nil {
return nil, err
}
return db, nil
}
func TestDbConnect(t *testing.T) { func TestDbConnect(t *testing.T) {
db := DbConnect() db := DbConnect(":memory:")
_ = db _ = db
} }
func TestDbAddUser(t *testing.T) { func TestDbAddUser(t *testing.T) {
db := DbConnect() db, err := setupState()
err := db.AddUser("test", "password") if err != nil {
t.Error("setupState failed:", err)
}
err = db.AddUser("test", "password")
if err != nil { if err != nil {
t.Error("AddUser failed:", err) t.Error("AddUser failed:", err)
} }
} }
func TestDbGetUserId(t *testing.T) { func TestDbGetUserId(t *testing.T) {
db := DbConnect() db, err := setupState()
if err != nil {
t.Error("setupState failed:", err)
}
db.AddUser("test", "password")
var id int var id int
id, err := db.GetUserId("test") id, err = db.GetUserId("test")
if err != nil { if err != nil {
t.Error("GetUserId failed:", err) t.Error("GetUserId failed:", err)
} }
@ -35,16 +50,24 @@ func TestDbGetUserId(t *testing.T) {
} }
func TestDbAddProject(t *testing.T) { func TestDbAddProject(t *testing.T) {
db := DbConnect() db, err := setupState()
err := db.AddProject("test", "description", "test") if err != nil {
t.Error("setupState failed:", err)
}
err = db.AddProject("test", "description", "test")
if err != nil { if err != nil {
t.Error("AddProject failed:", err) t.Error("AddProject failed:", err)
} }
} }
func TestDbRemoveUser(t *testing.T) { func TestDbRemoveUser(t *testing.T) {
db := DbConnect() db, err := setupState()
err := db.RemoveUser("test") if err != nil {
t.Error("setupState failed:", err)
}
err = db.RemoveUser("test")
if err != nil { if err != nil {
t.Error("RemoveUser failed:", err) t.Error("RemoveUser failed:", err)
} }