diff --git a/backend/Makefile b/backend/Makefile index da0e254..65a2f3c 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -10,6 +10,7 @@ DB_FILE = db.sqlite3 # Directory containing migration SQL scripts MIGRATIONS_DIR = internal/database/migrations +SAMPLE_DATA_DIR = internal/database/sample_data # Build target build: @@ -54,6 +55,14 @@ migrate: sqlite3 $(DB_FILE) < $$file; \ done +sampledata: + @echo "If this ever fails, run make clean and try again" + @echo "Migrating database $(DB_FILE) using SQL scripts in $(SAMPLE_DATA_DIR)" + @for file in $(wildcard $(SAMPLE_DATA_DIR)/*.sql); do \ + echo "Applying migration: $$file"; \ + sqlite3 $(DB_FILE) < $$file; \ + done + # Target added primarily for CI/CD to ensure that the database is created before running tests db.sqlite3: make migrate diff --git a/backend/internal/database/db.go b/backend/internal/database/db.go index e2aa366..25dd04b 100644 --- a/backend/internal/database/db.go +++ b/backend/internal/database/db.go @@ -20,6 +20,7 @@ type Database interface { GetUserId(username string) (int, error) AddProject(name string, description string, username string) error Migrate() error + MigrateSampleData() error GetProjectId(projectname string) (int, error) AddWeeklyReport(projectName string, userName string, week int, developmentTime int, meetingTime int, adminTime int, ownWorkTime int, studyTime int, testingTime int) error AddUserToProject(username string, projectname string, role string) error @@ -49,6 +50,9 @@ type UserProjectMember struct { //go:embed migrations var scripts embed.FS +//go:embed sample_data +var sampleData embed.FS + // TODO: Possibly break these out into separate files bundled with the embed package? const userInsert = "INSERT INTO users (username, password) VALUES (?, ?)" const projectInsert = "INSERT INTO projects (name, description, owner_user_id) SELECT ?, ?, id FROM users WHERE username = ?" @@ -60,9 +64,10 @@ const addWeeklyReport = `WITH UserLookup AS (SELECT id FROM users WHERE username const addUserToProject = "INSERT INTO user_roles (user_id, project_id, p_role) VALUES (?, ?, ?)" // WIP const changeUserRole = "UPDATE user_roles SET p_role = ? WHERE user_id = ? AND project_id = ?" -const getProjectsForUser = `SELECT projects.id, projects.name, projects.description, projects.owner_user_id - FROM projects JOIN user_roles ON projects.id = user_roles.project_id - JOIN users ON user_roles.user_id = users.id WHERE users.username = ?;` +const getProjectsForUser = `SELECT p.id, p.name, p.description FROM projects p + JOIN user_roles ur ON p.id = ur.project_id + JOIN users u ON ur.user_id = u.id + WHERE u.username = ?` // DbConnect connects to the database func DbConnect(dbpath string) Database { @@ -378,3 +383,42 @@ func (d *Db) Migrate() error { return nil } + +// MigrateSampleData applies sample data to the database. +func (d *Db) MigrateSampleData() error { + // Insert sample data + files, err := sampleData.ReadDir("sample_data") + if err != nil { + return err + } + + if len(files) == 0 { + println("No sample data files found") + } + tr := d.MustBegin() + + // Iterate over each SQL file and execute it + for _, file := range files { + if file.IsDir() || filepath.Ext(file.Name()) != ".sql" { + continue + } + + // This is perhaps not the most elegant way to do this + sqlBytes, err := sampleData.ReadFile("sample_data/" + file.Name()) + if err != nil { + return err + } + + sqlQuery := string(sqlBytes) + _, err = tr.Exec(sqlQuery) + if err != nil { + return err + } + } + + if tr.Commit() != nil { + return err + } + + return nil +} diff --git a/backend/internal/database/migrations/0010_users.sql b/backend/internal/database/migrations/0010_users.sql index d2e2dd1..15b1373 100644 --- a/backend/internal/database/migrations/0010_users.sql +++ b/backend/internal/database/migrations/0010_users.sql @@ -4,11 +4,9 @@ -- password is the hashed password CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, - userId TEXT DEFAULT (HEX(RANDOMBLOB(4))) NOT NULL UNIQUE, username VARCHAR(255) NOT NULL UNIQUE, password VARCHAR(255) NOT NULL ); -- Users are commonly searched by username and userId CREATE INDEX IF NOT EXISTS users_username_index ON users (username); -CREATE INDEX IF NOT EXISTS users_userId_index ON users (userId); \ No newline at end of file diff --git a/backend/internal/database/sample_data/0010_sample_data.sql b/backend/internal/database/sample_data/0010_sample_data.sql new file mode 100644 index 0000000..4dac91b --- /dev/null +++ b/backend/internal/database/sample_data/0010_sample_data.sql @@ -0,0 +1,35 @@ +INSERT OR IGNORE INTO users(username, password) +VALUES ("admin", "123"); + +INSERT OR IGNORE INTO users(username, password) +VALUES ("user", "123"); + +INSERT OR IGNORE INTO users(username, password) +VALUES ("user2", "123"); + +INSERT OR IGNORE INTO projects(name,description,owner_user_id) +VALUES ("projecttest","test project", 1); + +INSERT OR IGNORE INTO projects(name,description,owner_user_id) +VALUES ("projecttest2","test project2", 1); + +INSERT OR IGNORE INTO projects(name,description,owner_user_id) +VALUES ("projecttest3","test project3", 1); + +INSERT OR IGNORE INTO user_roles(user_id,project_id,p_role) +VALUES (1,1,"project_manager"); + +INSERT OR IGNORE INTO user_roles(user_id,project_id,p_role) +VALUES (2,1,"member"); + +INSERT OR IGNORE INTO user_roles(user_id,project_id,p_role) +VALUES (3,1,"member"); + +INSERT OR IGNORE INTO user_roles(user_id,project_id,p_role) +VALUES (3,2,"member"); + +INSERT OR IGNORE INTO user_roles(user_id,project_id,p_role) +VALUES (3,3,"member"); + +INSERT OR IGNORE INTO user_roles(user_id,project_id,p_role) +VALUES (2,1,"project_manager"); diff --git a/backend/internal/handlers/global_state.go b/backend/internal/handlers/global_state.go index 566d549..932451d 100644 --- a/backend/internal/handlers/global_state.go +++ b/backend/internal/handlers/global_state.go @@ -34,29 +34,17 @@ type GlobalState interface { // UpdateCollection(c *fiber.Ctx) error // To update a collection // DeleteCollection(c *fiber.Ctx) error // To delete a collection // SignCollection(c *fiber.Ctx) error // To sign a collection - GetButtonCount(c *fiber.Ctx) error // For demonstration purposes - IncrementButtonCount(c *fiber.Ctx) error // For demonstration purposes - ListAllUsers(c *fiber.Ctx) error // To get a list of all users in the application database - ListAllUsersProject(c *fiber.Ctx) error // To get a list of all users for a specific project - ProjectRoleChange(c *fiber.Ctx) error // To change a users role in a project + ListAllUsers(c *fiber.Ctx) error // To get a list of all users in the application database + ListAllUsersProject(c *fiber.Ctx) error // To get a list of all users for a specific project + ProjectRoleChange(c *fiber.Ctx) error // To change a users role in a project } // "Constructor" func NewGlobalState(db database.Database) GlobalState { - return &GState{Db: db, ButtonCount: 0} + return &GState{Db: db} } // The global state, which implements all the handlers type GState struct { - Db database.Database - ButtonCount int -} - -func (gs *GState) GetButtonCount(c *fiber.Ctx) error { - return c.Status(200).JSON(fiber.Map{"pressCount": gs.ButtonCount}) -} - -func (gs *GState) IncrementButtonCount(c *fiber.Ctx) error { - gs.ButtonCount++ - return c.Status(200).JSON(fiber.Map{"pressCount": gs.ButtonCount}) + Db database.Database } diff --git a/backend/main.go b/backend/main.go index 3e2fb75..16a033c 100644 --- a/backend/main.go +++ b/backend/main.go @@ -46,6 +46,12 @@ func main() { // Migrate the database if err = db.Migrate(); err != nil { fmt.Println("Error migrating database: ", err) + os.Exit(1) + } + + if err = db.MigrateSampleData(); err != nil { + fmt.Println("Error migrating sample data: ", err) + os.Exit(1) } // Get our global state @@ -53,6 +59,7 @@ func main() { // Create the server server := fiber.New() + // Mounts the swagger documentation, this is available at /swagger/index.html server.Get("/swagger/*", swagger.HandlerDefault) // Mount our static files (Beware of the security implications of this!) @@ -61,11 +68,6 @@ func main() { // Register our unprotected routes server.Post("/api/register", gs.Register) - - // Register handlers for example button count - server.Get("/api/button", gs.GetButtonCount) - server.Post("/api/button", gs.IncrementButtonCount) - server.Post("/api/login", gs.Login) // Every route from here on will require a valid JWT @@ -73,6 +75,7 @@ func main() { SigningKey: jwtware.SigningKey{Key: []byte("secret")}, })) + // Protected routes (require a valid JWT bearer token authentication header) server.Post("/api/submitReport", gs.SubmitWeeklyReport) server.Get("/api/getUserProjects", gs.GetUserProjects) server.Post("/api/loginrenew", gs.LoginRenew) diff --git a/testing.py b/testing.py index 6381afc..c094dca 100644 --- a/testing.py +++ b/testing.py @@ -28,6 +28,21 @@ addUserToProjectPath = base_url + "/api/addUserToProject" promoteToAdminPath = base_url + "/api/promoteToAdmin" getUserProjectsPath = base_url + "/api/getUserProjects" +def test_get_user_projects(): + + print("Testing get user projects") + loginResponse = login("user2", "123") + # Check if the user is added to the project + response = requests.get( + getUserProjectsPath, + json={"username": "user2"}, + headers={"Authorization": "Bearer " + loginResponse.json()["token"]}, + ) + print(response.text) + print(response.json()) + assert response.status_code == 200, "Get user projects failed" + print("got user projects successfully") + # Posts the username and password to the register endpoint def register(username: string, password: string): @@ -146,17 +161,7 @@ def test_add_user_to_project(): print(response.text) assert response.status_code == 200, "Add user to project failed" - print("Add user to project successful") - - # Check if the user is added to the project - response = requests.get( - getUserProjectsPath, - json={"username": new_user}, - headers={"Authorization": "Bearer " + admin_token}, - ) - print(response.text) - assert response.status_code == 200, "Get user projects failed" - print("got user projects successfully") + print("Add user to project successful") # Test function to sign a report def test_sign_report(): @@ -235,6 +240,7 @@ def test_sign_report(): if __name__ == "__main__": + test_get_user_projects() test_create_user() test_login() test_add_project()