2024-02-12 12:40:49 +01:00
|
|
|
package database
|
|
|
|
|
|
|
|
import (
|
2024-02-28 03:30:05 +01:00
|
|
|
"embed"
|
2024-02-28 03:21:13 +01:00
|
|
|
"log"
|
2024-02-12 12:40:49 +01:00
|
|
|
"os"
|
2024-02-28 03:21:13 +01:00
|
|
|
"path/filepath"
|
2024-02-12 12:40:49 +01:00
|
|
|
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
)
|
|
|
|
|
2024-02-27 05:00:04 +01:00
|
|
|
// This struct is a wrapper type that holds the database connection
|
|
|
|
// Internally DB holds a connection pool, so it's safe for concurrent use
|
|
|
|
type Db struct {
|
|
|
|
*sqlx.DB
|
|
|
|
}
|
|
|
|
|
2024-02-28 03:30:05 +01:00
|
|
|
// scripts holds the "migrations" directory of this package, embedded into
// the compiled binary at build time. Migrate reads the list of migration
// scripts from it. The //go:embed directive below must stay directly above
// the variable declaration.
//
//go:embed migrations
var scripts embed.FS
|
|
|
|
|
2024-02-27 05:00:04 +01:00
|
|
|
// SQL statements shared by the Db methods.
const (
	// userInsert adds a row to users; its parameters are the username and
	// the password, in that order.
	userInsert = "INSERT INTO users (username, password) VALUES (?, ?)"

	// projectInsert adds a project owned by an existing user. Its parameters
	// are the project name, the description and the owner's username. The
	// owner's id is resolved with a sub-select, so no row is inserted when
	// no user has that username.
	projectInsert = "INSERT INTO projects (name, description, user_id) SELECT ?, ?, id FROM users WHERE username = ?"
)
|
2024-02-27 05:00:04 +01:00
|
|
|
|
|
|
|
// DbConnect connects to the database
|
2024-02-28 03:21:13 +01:00
|
|
|
func DbConnect(dbpath string) *Db {
|
2024-02-12 12:40:49 +01:00
|
|
|
// Open the database
|
|
|
|
db, err := sqlx.Connect("sqlite3", dbpath)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
2024-02-27 05:00:04 +01:00
|
|
|
// Ping forces the connection to be established
|
2024-02-12 12:40:49 +01:00
|
|
|
err = db.Ping()
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
2024-02-27 05:00:04 +01:00
|
|
|
return &Db{db}
|
|
|
|
}
|
|
|
|
|
|
|
|
// AddUser adds a user to the database
|
|
|
|
func (d *Db) AddUser(username string, password string) error {
|
|
|
|
_, err := d.Exec(userInsert, username, password)
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Removes a user from the database
|
|
|
|
func (d *Db) RemoveUser(username string) error {
|
|
|
|
_, err := d.Exec("DELETE FROM users WHERE username = ?", username)
|
|
|
|
return err
|
2024-02-12 12:40:49 +01:00
|
|
|
}
|
2024-02-27 05:51:16 +01:00
|
|
|
|
|
|
|
func (d *Db) GetUserId(username string) (int, error) {
|
|
|
|
var id int
|
|
|
|
err := d.Get(&id, "SELECT id FROM users WHERE username = ?", username)
|
|
|
|
return id, err
|
|
|
|
}
|
|
|
|
|
2024-02-27 07:59:42 +01:00
|
|
|
// Creates a new project in the database, associated with a user
|
2024-02-27 05:51:16 +01:00
|
|
|
func (d *Db) AddProject(name string, description string, username string) error {
|
2024-02-27 07:59:42 +01:00
|
|
|
_, err := d.Exec(projectInsert, name, description, username)
|
2024-02-27 05:51:16 +01:00
|
|
|
return err
|
|
|
|
}
|
2024-02-28 03:21:13 +01:00
|
|
|
|
|
|
|
// Reads a directory of migration files and applies them to the database.
|
|
|
|
// This will eventually be used on an embedded directory
|
|
|
|
func (d *Db) Migrate(dirname string) error {
|
2024-02-28 03:30:05 +01:00
|
|
|
// Read the embedded scripts directory
|
|
|
|
files, err := scripts.ReadDir("migrations")
|
2024-02-28 03:21:13 +01:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
tr := d.MustBegin()
|
|
|
|
|
|
|
|
// Iterate over each SQL file and execute it
|
|
|
|
for _, file := range files {
|
|
|
|
if file.IsDir() || filepath.Ext(file.Name()) != ".sql" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2024-02-28 03:30:05 +01:00
|
|
|
// This is perhaps not the most elegant way to do this
|
|
|
|
sqlFile := filepath.Join("migrations", file.Name())
|
2024-02-28 03:21:13 +01:00
|
|
|
sqlBytes, err := os.ReadFile(sqlFile)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
sqlQuery := string(sqlBytes)
|
|
|
|
_, err = tr.Exec(sqlQuery)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
log.Println("Executed SQL file:", file.Name())
|
|
|
|
}
|
|
|
|
|
|
|
|
tr.Commit()
|
|
|
|
return nil
|
|
|
|
}
|