package database import ( "embed" "log" "os" "path/filepath" "github.com/jmoiron/sqlx" _ "github.com/mattn/go-sqlite3" ) // Interface for the database type Database interface { AddUser(username string, password string) error RemoveUser(username string) error PromoteToAdmin(username string) error GetUserId(username string) (int, error) AddProject(name string, description string, username string) error Migrate(dirname string) error } // This struct is a wrapper type that holds the database connection // Internally DB holds a connection pool, so it's safe for concurrent use type Db struct { *sqlx.DB } //go:embed migrations var scripts embed.FS const userInsert = "INSERT INTO users (username, password) VALUES (?, ?)" const projectInsert = "INSERT INTO projects (name, description, user_id) SELECT ?, ?, id FROM users WHERE username = ?" const promoteToAdmin = "INSERT INTO site_admin (admin_id) SELECT id FROM users WHERE username = ?" // DbConnect connects to the database func DbConnect(dbpath string) Database { // Open the database db, err := sqlx.Connect("sqlite3", dbpath) if err != nil { panic(err) } // Ping forces the connection to be established err = db.Ping() if err != nil { panic(err) } return &Db{db} } // AddUser adds a user to the database func (d *Db) AddUser(username string, password string) error { _, err := d.Exec(userInsert, username, password) return err } // Removes a user from the database func (d *Db) RemoveUser(username string) error { _, err := d.Exec("DELETE FROM users WHERE username = ?", username) return err } func (d *Db) PromoteToAdmin(username string) error { _, err := d.Exec(promoteToAdmin, username) return err } func (d *Db) GetUserId(username string) (int, error) { var id int err := d.Get(&id, "SELECT id FROM users WHERE username = ?", username) return id, err } // Creates a new project in the database, associated with a user func (d *Db) AddProject(name string, description string, username string) error { _, err := d.Exec(projectInsert, name, description, username) return err } // Reads a directory of migration files and applies them to the database. // This will eventually be used on an embedded directory func (d *Db) Migrate(dirname string) error { // Read the embedded scripts directory files, err := scripts.ReadDir("migrations") if err != nil { return err } tr := d.MustBegin() // Iterate over each SQL file and execute it for _, file := range files { if file.IsDir() || filepath.Ext(file.Name()) != ".sql" { continue } // This is perhaps not the most elegant way to do this sqlFile := filepath.Join("migrations", file.Name()) sqlBytes, err := os.ReadFile(sqlFile) if err != nil { return err } sqlQuery := string(sqlBytes) _, err = tr.Exec(sqlQuery) if err != nil { return err } log.Println("Executed SQL file:", file.Name()) } if tr.Commit() != nil { return err } return nil }